package com.alibaba.alink.operator.common.tree.paralleltree;

import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.LabelCounter;
import com.alibaba.alink.operator.common.tree.Node;
import java.util.Arrays;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/paralleltree/RegObj.class */
/**
 * Variance-reduction (regression) split objective for the parallel tree builder.
 *
 * <p>Histogram layout: every node owns a slice of {@link #lenPerStat()} doubles. Inside that
 * slice each (possibly bagged) feature slot owns {@code nBin * STAT_SIZE} doubles. Each bin
 * stores, in order, the sum of labels, the sum of squared labels and the number of rows that
 * fell into the bin.
 */
public class RegObj extends TreeObj<double[]> {
    private static final long serialVersionUID = 2055187633851169618L;

    /** Doubles stored per histogram bin: label sum, label square sum, row count. */
    private static final int STAT_SIZE = 3;

    private double[] labels;

    public RegObj(Params params, QuantileDiscretizerModelDataConverter quantileDiscretizerModelDataConverter, FeatureMeta[] featureMetaArr, FeatureMeta featureMeta) {
        super(params, quantileDiscretizerModelDataConverter, featureMetaArr, featureMeta);
        initBuffer();
    }

    /**
     * Resets {@code hist[histStart, histEnd)} and accumulates label statistics for the rows
     * referenced by {@code partitions[start, end)} into it.
     *
     * @param loopId    index into {@code loopBuffer}; only read when feature bagging is active
     * @param start     first position (inclusive) in {@code partitions}
     * @param end       last position (exclusive) in {@code partitions}
     * @param histStart first histogram slot (inclusive) to reset and fill
     * @param histEnd   last histogram slot (exclusive) to reset
     */
    @Override
    public void stat(int loopId, int start, int end, int histStart, int histEnd) {
        // NOTE(review): the reset value is Criteria.INVALID_GAIN; this clears the histogram
        // correctly only if that constant is 0.0 -- confirm in Criteria.
        Arrays.fill(this.hist, histStart, histEnd, Criteria.INVALID_GAIN);

        int featureCount = baggingFeatureCount();
        // The bagged-feature list is only consulted when bagging actually drops features.
        int[] baggingFeatures = featureCount == this.nFeatureCol ? null : this.loopBuffer[loopId].baggingFeatures;

        for (int slot = 0; slot < featureCount; slot++) {
            int featureId = baggingFeatures == null ? slot : baggingFeatures[slot];
            // features is laid out column-major: numLocalRow entries per feature.
            int featureOffset = featureId * this.numLocalRow;
            int slotOffset = (slot * STAT_SIZE * this.nBin) + histStart;
            for (int p = start; p < end; p++) {
                int row = this.partitions[p];
                int bin = slotOffset + (this.features[featureOffset + row] * STAT_SIZE);
                double label = this.labels[row];
                this.hist[bin] += label;
                this.hist[bin + 1] += label * label;
                this.hist[bin + 2] += 1.0d;
            }
        }
    }

    /**
     * Returns the number of histogram doubles used per node: {@link #STAT_SIZE} doubles for
     * every bin of every (possibly bagged) feature.
     */
    @Override
    public final int lenPerStat() {
        return baggingFeatureCount() * STAT_SIZE * this.nBin;
    }

    /**
     * Stores the label array that {@link #stat} indexes by row id.
     *
     * @param dArr labels, indexed by the row ids stored in {@code partitions}
     */
    @Override
    public void setLabels(double[] dArr) {
        this.labels = dArr;
    }

    /**
     * Computes {@code totalVariance - (nL / n) * varL - (nR / n) * varR}, where each child
     * variance is {@code E[y^2] - E[y]^2}. The operation order matches the original inline
     * formula, so the results are bit-identical.
     */
    private static double varianceGain(double totalVariance, double totalCount,
                                       double leftSum, double leftSquareSum, double leftCount,
                                       double rightSum, double rightSquareSum, double rightCount) {
        double leftMean = leftSum / leftCount;
        double leftVariance = (leftSquareSum / leftCount) - (leftMean * leftMean);
        double rightMean = rightSum / rightCount;
        double rightVariance = (rightSquareSum / rightCount) - (rightMean * rightMean);
        return (totalVariance - ((leftCount / totalCount) * leftVariance))
            - ((rightCount / totalCount) * rightVariance);
    }

    /**
     * Finds the best binary partition of a categorical feature's categories by exhaustive
     * enumeration.
     *
     * @param start          offset in {@code minusHist} of the feature's first bin
     * @param totalVariance  label variance of the whole node
     * @param totalSquareSum not read by this implementation; kept for signature compatibility
     * @param totalSum       not read by this implementation; kept for signature compatibility
     * @param totalCount     number of rows in the node
     * @param numCategorical number of categories (one histogram bin each)
     * @return per-category side assignment (0 = left, 1 = right) of the best split and its
     *     gain. When no candidate beats a gain of 0 the assignment is all zeros and the gain is 0.
     */
    public final Tuple2<int[], Double> bestSplitCategorical(int start, double totalVariance, double totalSquareSum,
                                                            double totalSum, double totalCount, int numCategorical) {
        double bestGain = 0.0d;
        int[] bestAssignment = new int[numCategorical];
        int[] candidate = new int[numCategorical];

        // The last category's bit is never set, so it always goes right. Each partition is
        // therefore visited once rather than together with its mirror image.
        // NOTE(review): masks are 32-bit ints; Java masks shift distances to 5 bits, so with
        // more than 31 categories the enumeration wraps or is empty. Presumably the category
        // count is bounded upstream -- confirm.
        int numMasks = 1 << (numCategorical - 1);
        for (int mask = 1; mask < numMasks; mask++) {
            double leftSum = 0.0d;
            double leftSquareSum = 0.0d;
            double leftCount = 0.0d;
            double rightSum = 0.0d;
            double rightSquareSum = 0.0d;
            double rightCount = 0.0d;
            for (int category = 0; category < numCategorical; category++) {
                int bin = start + (category * STAT_SIZE);
                if ((mask & (1 << category)) != 0) {
                    candidate[category] = 0;
                    leftSum += this.minusHist[bin];
                    leftSquareSum += this.minusHist[bin + 1];
                    leftCount += this.minusHist[bin + 2];
                } else {
                    candidate[category] = 1;
                    rightSum += this.minusHist[bin];
                    rightSquareSum += this.minusHist[bin + 1];
                    rightCount += this.minusHist[bin + 2];
                }
            }
            if (this.minSamplesPerLeaf <= leftCount && this.minSamplesPerLeaf <= rightCount) {
                double gain = varianceGain(totalVariance, totalCount,
                    leftSum, leftSquareSum, leftCount, rightSum, rightSquareSum, rightCount);
                if (gain > bestGain) {
                    bestGain = gain;
                    System.arraycopy(candidate, 0, bestAssignment, 0, numCategorical);
                }
            }
        }
        return Tuple2.of(bestAssignment, bestGain);
    }

    /**
     * Finds the best threshold bin for a continuous (binned) feature by sweeping bins left to
     * right. Bins {@code [0, b]} go left and the rest go right.
     *
     * @param start          offset in {@code minusHist} of the feature's first bin
     * @param totalVariance  label variance of the whole node
     * @param totalSquareSum sum of squared labels in the node
     * @param totalSum       sum of labels in the node
     * @param totalCount     number of rows in the node
     * @return index of the last bin on the left side and its gain. When no candidate beats a
     *     gain of 0, returns {@code (0, 0.0)}.
     */
    public final Tuple2<Integer, Double> bestSplitNumerical(int start, double totalVariance, double totalSquareSum,
                                                            double totalSum, double totalCount) {
        double bestGain = 0.0d;
        int bestBin = 0;

        double leftSum = 0.0d;
        double leftSquareSum = 0.0d;
        double leftCount = 0.0d;
        double rightSum = totalSum;
        double rightSquareSum = totalSquareSum;
        double rightCount = totalCount;

        // The last bin can never be a threshold: the right side would be empty.
        for (int b = 0; b < this.nBin - 1; b++) {
            int bin = start + (b * STAT_SIZE);
            leftSum += this.minusHist[bin];
            leftSquareSum += this.minusHist[bin + 1];
            leftCount += this.minusHist[bin + 2];
            rightSum -= this.minusHist[bin];
            rightSquareSum -= this.minusHist[bin + 1];
            rightCount -= this.minusHist[bin + 2];

            if (this.minSamplesPerLeaf <= leftCount && this.minSamplesPerLeaf <= rightCount) {
                double gain = varianceGain(totalVariance, totalCount,
                    leftSum, leftSquareSum, leftCount, rightSum, rightSquareSum, rightCount);
                if (gain > bestGain) {
                    bestGain = gain;
                    bestBin = b;
                }
            }
        }
        return Tuple2.of(bestBin, bestGain);
    }

    /**
     * Chooses the best split for {@code node} from the node's slice of {@code minusHist}, or
     * turns the node into a leaf.
     *
     * <p>The node becomes a leaf when its depth exceeds {@code maxDepth}, when it has fewer
     * than {@code minSamplesPerLeaf} rows, or when no feature yields a gain above
     * {@link Criteria#INVALID_GAIN}. Its label counter (count, {sum, squareSum}) is set in
     * every case.
     *
     * @param node         node to populate
     * @param minusIndex   index of the node's slice in {@code minusHist}, in units of
     *                     {@link #lenPerStat()}
     * @param nodeInfoPair supplies the node depth and, when bagging, the bagged feature ids
     */
    @Override
    public final void bestSplit(Node node, int minusIndex, NodeInfoPair nodeInfoPair) {
        int start = minusIndex * lenPerStat();

        // Every feature sees the same rows, so the node totals can be taken from the first
        // feature slot's bins.
        double sum = 0.0d;
        double squareSum = 0.0d;
        double count = 0.0d;
        for (int b = 0; b < this.nBin; b++) {
            int bin = start + (b * STAT_SIZE);
            sum += this.minusHist[bin];
            squareSum += this.minusHist[bin + 1];
            count += this.minusHist[bin + 2];
        }
        node.setCounter(new LabelCounter(count, 0, new double[] {sum, squareSum}));

        if (this.maxDepth < nodeInfoPair.depth || this.minSamplesPerLeaf > count) {
            node.makeLeaf();
            return;
        }

        double mean = sum / count;
        double variance = (squareSum / count) - (mean * mean);

        int featureCount = baggingFeatureCount();
        // The bagged-feature list is only consulted when bagging actually drops features.
        int[] baggingFeatures = featureCount == this.nFeatureCol ? null : nodeInfoPair.baggingFeatures;

        double bestGain = 0.0d;
        int bestFeature = -1;
        int bestBin = 0;
        int[] bestAssignment = null;

        for (int slot = 0; slot < featureCount; slot++) {
            int featureId = baggingFeatures == null ? slot : baggingFeatures[slot];
            int featureStart = start + (slot * this.nBin * STAT_SIZE);
            FeatureMeta meta = this.featureMetas[featureId];

            if (meta.getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
                Tuple2<Integer, Double> candidate = bestSplitNumerical(featureStart, variance, squareSum, sum, count);
                if (candidate.f1 > bestGain) {
                    bestGain = candidate.f1;
                    bestBin = candidate.f0;
                    bestFeature = featureId;
                }
            } else {
                Tuple2<int[], Double> candidate = bestSplitCategorical(
                    featureStart, variance, squareSum, sum, count, meta.getNumCategorical());
                if (candidate.f1 > bestGain) {
                    bestGain = candidate.f1;
                    bestAssignment = candidate.f0;
                    bestFeature = featureId;
                }
            }
        }

        // bestFeature stays -1 only when no candidate beat a gain of 0. Guard it explicitly so
        // correctness does not depend on the value of INVALID_GAIN.
        if (bestFeature < 0 || bestGain <= Criteria.INVALID_GAIN) {
            node.makeLeaf();
            return;
        }

        node.setFeatureIndex(bestFeature);
        if (this.featureMetas[bestFeature].getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            node.setContinuousSplit(bestBin);
        } else {
            node.setCategoricalSplit(bestAssignment);
        }
        node.setGain(bestGain);
    }
}
