package com.alibaba.alink.operator.common.classification.ann;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/ann/Stacker.class */
/**
 * Packs a batch of (label, feature vector) samples into a single flat vector and unpacks such a
 * vector back into a feature matrix and a label matrix.
 *
 * <p>Stacked layout for {@code n} samples: the {@code inputSize} features of sample 0, then those
 * of sample 1, and so on, followed by the {@code n} labels in sample order. The total length is
 * {@code n * inputSize + n}.
 *
 * <p>{@link #unstack} reuses its output matrices, which are held in transient fields, and
 * overwrites them on every call. Instances are therefore not safe for concurrent use, and callers
 * must not keep a returned matrix across calls.
 */
public class Stacker implements Serializable {
    /** Batch-size constant exposed to callers; it is not referenced within this class. */
    public static final int BATCH_SIZE = 64;
    private static final long serialVersionUID = -6416234078414747788L;
    /** Number of features per sample. */
    private final int inputSize;
    /** Width of one-hot label rows; one-hot labels must lie in {@code [0, outputSize)}. */
    private final int outputSize;
    /** Whether {@link #unstack} expands each label into a one-hot row. */
    private final boolean onehot;
    // Reused output buffers. They are transient, so they are null after deserialization and get
    // rebuilt lazily by unstack().
    private transient DenseMatrix features;
    private transient DenseMatrix labels;

    /**
     * @param inputSize  number of features per sample
     * @param outputSize width of the one-hot label rows; only used when {@code onehot} is true
     * @param onehot     whether {@link #unstack} expands labels into one-hot rows
     */
    public Stacker(int inputSize, int outputSize, boolean onehot) {
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        this.onehot = onehot;
    }

    /**
     * Stacks the first {@code numSamples} entries of {@code batch} into one flat vector.
     *
     * @param batch      samples as (label, features) pairs; each feature vector must hold at least
     *                   {@code inputSize} values
     * @param numSamples number of leading samples of {@code batch} to stack
     * @return a tuple of ({@code numSamples}, {@code Criteria.INVALID_GAIN}, stacked vector)
     * @throws IndexOutOfBoundsException if {@code batch} has fewer than {@code numSamples} entries
     */
    public Tuple3<Double, Double, Vector> stack(
            List<Tuple2<Double, DenseVector>> batch, int numSamples) {
        DenseVector stacked = new DenseVector(this.inputSize * numSamples + numSamples);
        double[] target = stacked.getData();
        int featureOffset = 0;
        int labelOffset = this.inputSize * numSamples;
        // Iterate once over a subList view rather than calling batch.get(i). On linked lists,
        // get(i) is O(i), which made the indexed loops quadratic.
        for (Tuple2<Double, DenseVector> sample : batch.subList(0, numSamples)) {
            System.arraycopy(sample.f1.getData(), 0, target, featureOffset, this.inputSize);
            featureOffset += this.inputSize;
            stacked.set(labelOffset, sample.f0);
            labelOffset++;
        }
        // NOTE(review): unstack() never reads the middle field; it appears to be a placeholder.
        return Tuple3.of((double) numSamples, Criteria.INVALID_GAIN, stacked);
    }

    /**
     * Unpacks a tuple produced by {@link #stack} into feature and label matrices.
     *
     * <p>The returned matrices are cached in this instance and are overwritten by the next call.
     *
     * @param stackedBatch tuple whose first field is the sample count and whose third field is the
     *                     stacked {@link DenseVector}; the second field is ignored
     * @return (features, labels). Features is {@code numSamples x inputSize}. Labels is
     *         {@code numSamples x outputSize} (one-hot) when {@code onehot} is set, and
     *         {@code numSamples x 1} otherwise.
     * @throws AkIllegalDataException if {@code onehot} is set and a label is not an integral
     *                                value in {@code [0, outputSize)}
     */
    public Tuple2<DenseMatrix, DenseMatrix> unstack(Tuple3<Double, Double, Vector> stackedBatch) {
        int numSamples = stackedBatch.f0.intValue();
        DenseVector stacked = (DenseVector) stackedBatch.f2;

        if (this.features == null || this.features.numRows() != numSamples) {
            this.features = new DenseMatrix(numSamples, this.inputSize);
        }
        int pos = 0;
        for (int row = 0; row < numSamples; row++) {
            for (int col = 0; col < this.inputSize; col++) {
                this.features.set(row, col, stacked.get(pos));
                pos++;
            }
        }

        if (this.labels == null || this.labels.numRows() != numSamples) {
            this.labels = new DenseMatrix(numSamples, this.onehot ? this.outputSize : 1);
        }
        int labelOffset = numSamples * this.inputSize;
        if (this.onehot) {
            // NOTE(review): a valid one-hot encoding requires this fill value to be 0.0. The
            // constant name looks like a decompiler substitution for a literal 0.0 — confirm.
            Arrays.fill(this.labels.getData(), Criteria.INVALID_GAIN);
            for (int row = 0; row < numSamples; row++) {
                this.labels.set(row, toClassIndex(stacked.get(labelOffset + row)), 1.0d);
            }
        } else {
            for (int row = 0; row < numSamples; row++) {
                this.labels.set(row, 0, stacked.get(labelOffset + row));
            }
        }
        return Tuple2.of(this.features, this.labels);
    }

    /**
     * Converts a raw label into a one-hot column index.
     *
     * @param rawLabel label value read from the stacked vector
     * @return the label as an index in {@code [0, outputSize)}
     * @throws AkIllegalDataException if the label is NaN, non-integral, or out of range
     */
    private int toClassIndex(double rawLabel) {
        int index = (int) rawLabel;
        // A bare (int) cast maps NaN to 0 and truncates toward zero (-0.5 -> 0, 1.7 -> 1), which
        // would silently produce a plausible but wrong class. Comparing the cast back to the raw
        // value rejects NaN, infinities, and any fractional label.
        if (index != rawLabel || index < 0 || index >= this.outputSize) {
            throw new AkIllegalDataException("Invalid target value: " + rawLabel);
        }
        return index;
    }
}
