package com.alibaba.alink.operator.common.tree.paralleltree;

import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.Node;
import java.util.Arrays;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/paralleltree/ClassifierObj.class */
/**
 * Split-finding objective for classification trees built by the parallel tree trainer.
 *
 * <p>For each node it counts, per (feature slot, bin, label), the rows that fall into that cell and
 * then searches every candidate feature for the split with the largest Gini gain. The cells of one
 * node are laid out as {@code histStart + (featureSlot * nBin + bin) * nLabels + label}, which is why
 * {@link #lenPerStat()} is {@code nLabels * featureCount * nBin}.
 *
 * <p>NOTE(review): this class appears to be decompiled (see the "loaded from" header above it). The
 * uses of {@code Criteria.INVALID_GAIN} as an initial value, fill value or threshold, and the
 * {@code mo606clone()} call, may be decompiler substitutions for a literal / {@code clone()} —
 * confirm against the original source before relying on them.
 */
public class ClassifierObj extends TreeObj<int[]> {
    private static final Logger LOG = LoggerFactory.getLogger(ClassifierObj.class);
    private static final long serialVersionUID = -913607375511779459L;

    /** Label statistics of the node currently being split; rebuilt at the start of {@link #bestSplit}. */
    private Criteria.Gini total;
    /** Label index of each local row, as supplied through {@link #setLabels(int[])}. */
    private int[] labels;
    /** Number of label values, taken from the label column's categorical count. */
    private int nLabels;

    /**
     * Creates the objective and allocates its working buffers.
     *
     * @param params training parameters, forwarded to {@link TreeObj}
     * @param quantileDiscretizerModelDataConverter binning model, forwarded to {@link TreeObj}
     * @param featureMetaArr metadata of the feature columns, forwarded to {@link TreeObj}
     * @param featureMeta metadata of the label column (presumably kept by {@link TreeObj} as
     *     {@code labelMeta}, whose categorical count becomes the number of labels — confirm)
     */
    public ClassifierObj(Params params, QuantileDiscretizerModelDataConverter quantileDiscretizerModelDataConverter, FeatureMeta[] featureMetaArr, FeatureMeta featureMeta) {
        super(params, quantileDiscretizerModelDataConverter, featureMetaArr, featureMeta);
        this.nLabels = this.labelMeta.getNumCategorical();
        this.total = new Criteria.Gini(Criteria.INVALID_GAIN, 0, new double[this.nLabels]);
        initBuffer();
    }

    /**
     * Returns the number of histogram cells one node needs: one per (feature slot, bin, label).
     * When every feature is used, {@code baggingFeatureCount()} equals {@code nFeatureCol}, so a
     * single expression covers both the bagged and the non-bagged case.
     */
    @Override
    public int lenPerStat() {
        return this.nLabels * baggingFeatureCount() * this.nBin;
    }

    /**
     * Stores the label index of every local row.
     *
     * @param iArr label index per local row; kept by reference, not copied
     */
    @Override
    public void setLabels(int[] iArr) {
        this.labels = iArr;
    }

    /**
     * Builds the label histogram of one node.
     *
     * <p>Resets {@code hist[histStart, histEnd)} to {@code Criteria.INVALID_GAIN}. Then, for every
     * candidate feature slot and every row listed in {@code partitions[partitionStart, partitionEnd)},
     * it adds 1 to the cell of that row's bin and label. The bin of a row for a feature is read from
     * {@code features[featureId * numLocalRow + row]}.
     *
     * @param loopId index into {@code loopBuffer}; its {@code baggingFeatures} map slots to feature
     *     indices when only a subset of the features is used
     * @param partitionStart first position (inclusive) in {@code partitions} of the node's rows
     * @param partitionEnd end position (exclusive) in {@code partitions} of the node's rows
     * @param histStart first histogram cell (inclusive) owned by this node
     * @param histEnd end histogram cell (exclusive) owned by this node
     */
    @Override
    public void stat(int loopId, int partitionStart, int partitionEnd, int histStart, int histEnd) {
        Arrays.fill(this.hist, histStart, histEnd, Criteria.INVALID_GAIN);
        int featureCount = baggingFeatureCount();
        // With all features selected, slot s is feature s; otherwise slots go through the bagging
        // features recorded for this loop entry.
        boolean allFeatures = featureCount == this.nFeatureCol;
        int[] baggedFeatures = allFeatures ? null : this.loopBuffer[loopId].baggingFeatures;
        for (int slot = 0; slot < featureCount; slot++) {
            int featureId = allFeatures ? slot : baggedFeatures[slot];
            int featureBase = featureId * this.numLocalRow;
            int slotBase = slot * this.nLabels * this.nBin + histStart;
            for (int p = partitionStart; p < partitionEnd; p++) {
                int row = this.partitions[p];
                this.hist[slotBase + this.features[featureBase + row] * this.nLabels + this.labels[row]] += 1.0d;
            }
        }
    }

    /**
     * Finds the best binary partition of a categorical feature's categories by exhaustive search.
     *
     * <p>Each candidate is a bit mask over categories {@code 0 .. categorySize - 2}. Category
     * {@code categorySize - 1} always goes right, so each unordered partition is tried once, and the
     * empty-left mask 0 is skipped. A candidate is ignored if it leaves fewer than
     * {@code minSamplesPerLeaf} instances on either side.
     *
     * <p>NOTE(review): the search is exponential in {@code categorySize}. Java masks int shift
     * distances to 5 bits, so {@code 1 << (categorySize - 1)} overflows for
     * {@code categorySize >= 32}; such features get no candidates, or only a partial set of them.
     *
     * @param start offset in {@code minusHist} of this feature's first (category, label) cell
     * @param categorySize number of categories of the feature
     * @return {@code f0}: side of each category in the best split (0 = left, 1 = right), all zeros if
     *     no candidate exceeded a gain of 0; {@code f1}: the best gain found (0.0 if none)
     */
    public final Tuple2<int[], Double> bestSplitCategorical(int start, int categorySize) {
        double bestGain = 0.0d;
        int[] bestSides = new int[categorySize];
        int[] sides = new int[categorySize];
        int maskLimit = 1 << (categorySize - 1);
        for (int mask = 1; mask < maskLimit; mask++) {
            Criteria.Gini left = new Criteria.Gini(Criteria.INVALID_GAIN, 0, new double[this.nLabels]);
            Criteria.Gini right = new Criteria.Gini(Criteria.INVALID_GAIN, 0, new double[this.nLabels]);
            for (int category = 0; category < categorySize; category++) {
                boolean goesLeft = (mask & (1 << category)) != 0;
                sides[category] = goesLeft ? 0 : 1;
                Criteria.Gini side = goesLeft ? left : right;
                int cellBase = start + category * this.nLabels;
                for (int label = 0; label < this.nLabels; label++) {
                    side.add(label, this.minusHist[cellBase + label], 1);
                }
            }
            if (this.minSamplesPerLeaf <= left.getNumInstances()
                && this.minSamplesPerLeaf <= right.getNumInstances()) {
                double gain = this.total.gain(left, right);
                if (gain > bestGain) {
                    bestGain = gain;
                    System.arraycopy(sides, 0, bestSides, 0, categorySize);
                }
            }
        }
        return Tuple2.of(bestSides, bestGain);
    }

    /**
     * Finds the best threshold split of a continuous (binned) feature.
     *
     * <p>Scans the bins from left to right. Each bin's label counts move from the right statistics
     * (a clone of {@link #total}) to the left statistics. Splitting after bin {@code b} is a
     * candidate for every {@code b < nBin - 1}. A candidate is ignored if it leaves fewer than
     * {@code minSamplesPerLeaf} instances on either side.
     *
     * @param start offset in {@code minusHist} of this feature's first (bin, label) cell
     * @return {@code f0}: index of the last bin on the left side of the best split (0 if none was
     *     found); {@code f1}: its gain (0.0 if none)
     * @throws Exception declared by the signature; presumably for cloning the total statistics
     */
    public final Tuple2<Integer, Double> bestSplitNumerical(int start) throws Exception {
        double bestGain = 0.0d;
        int bestBin = 0;
        Criteria.Gini left = new Criteria.Gini(Criteria.INVALID_GAIN, 0, new double[this.nLabels]);
        // mo606clone() is presumably the decompiler's name for Criteria's clone().
        Criteria.Gini right = (Criteria.Gini) this.total.mo606clone();
        for (int bin = 0; bin < this.nBin - 1; bin++) {
            int cellBase = start + bin * this.nLabels;
            for (int label = 0; label < this.nLabels; label++) {
                left.add(label, this.minusHist[cellBase + label], 1);
                right.subtract(label, this.minusHist[cellBase + label], 1);
            }
            if (this.minSamplesPerLeaf <= left.getNumInstances()
                && this.minSamplesPerLeaf <= right.getNumInstances()) {
                double gain = this.total.gain(left, right);
                if (gain > bestGain) {
                    bestGain = gain;
                    bestBin = bin;
                }
            }
        }
        return Tuple2.of(bestBin, bestGain);
    }

    /**
     * Chooses the split for {@code node} from the label histogram held in {@code minusHist}.
     *
     * <p>First it rebuilds {@link #total} from the node's first feature slot and stores it on the
     * node as its label counter. In {@link #stat} every row adds to exactly one bin per feature; so,
     * assuming {@code minusHist} is derived from those histograms, one slot holds the node's full
     * label counts.
     *
     * <p>The node becomes a leaf when:
     * <ul>
     *   <li>its depth exceeds {@code maxDepth},</li>
     *   <li>it has at most {@code 2 * minSamplesPerLeaf} instances, or</li>
     *   <li>no candidate split has a gain above {@code Criteria.INVALID_GAIN}.</li>
     * </ul>
     * Otherwise the feature with the highest gain wins; on ties the earliest slot is kept. The node
     * then receives that feature index, the continuous or categorical split, and the gain.
     *
     * @param node node to fill in
     * @param statSlot index of this node's statistics block, which starts at
     *     {@code statSlot * lenPerStat()} in {@code minusHist}
     * @param nodeInfoPair node context; supplies {@code depth} and, when only a subset of features
     *     is used, the {@code baggingFeatures} that map slots to feature indices
     * @throws Exception propagated from {@link #bestSplitNumerical(int)}
     */
    @Override
    public final void bestSplit(Node node, int statSlot, NodeInfoPair nodeInfoPair) throws Exception {
        int nodeOffset = statSlot * lenPerStat();

        this.total = new Criteria.Gini(Criteria.INVALID_GAIN, 0, new double[this.nLabels]);
        for (int bin = 0; bin < this.nBin; bin++) {
            int cellBase = nodeOffset + bin * this.nLabels;
            for (int label = 0; label < this.nLabels; label++) {
                this.total.add(label, this.minusHist[cellBase + label], 1);
            }
        }
        node.setCounter(this.total.toLabelCounter());
        if (this.total.getWeightSum() < Criteria.INVALID_GAIN) {
            // NOTE(review): presumably a diagnostic for a corrupted (negative-weight) histogram.
            LOG.info("total: {}", JsonConverter.gson.toJson(this.total));
        }
        if (this.maxDepth < nodeInfoPair.depth || this.minSamplesPerLeaf * 2 >= this.total.getNumInstances()) {
            node.makeLeaf();
            return;
        }

        int featureCount = baggingFeatureCount();
        // With all features selected, slot s is feature s; otherwise slots go through the node's
        // bagging features.
        boolean allFeatures = featureCount == this.nFeatureCol;
        int[] baggedFeatures = allFeatures ? null : nodeInfoPair.baggingFeatures;

        double bestGain = 0.0d;
        int bestFeature = -1;
        int bestBin = 0;
        int[] bestCategorySides = null;
        for (int slot = 0; slot < featureCount; slot++) {
            int featureId = allFeatures ? slot : baggedFeatures[slot];
            int featureStart = nodeOffset + slot * this.nBin * this.nLabels;
            FeatureMeta meta = this.featureMetas[featureId];
            if (meta.getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
                Tuple2<Integer, Double> candidate = bestSplitNumerical(featureStart);
                if (candidate.f1 > bestGain) {
                    bestGain = candidate.f1;
                    bestBin = candidate.f0;
                    bestFeature = featureId;
                }
            } else {
                Tuple2<int[], Double> candidate = bestSplitCategorical(featureStart, meta.getNumCategorical());
                if (candidate.f1 > bestGain) {
                    bestGain = candidate.f1;
                    bestCategorySides = candidate.f0;
                    bestFeature = featureId;
                }
            }
        }

        if (bestGain <= Criteria.INVALID_GAIN) {
            node.makeLeaf();
            return;
        }
        node.setFeatureIndex(bestFeature);
        if (this.featureMetas[bestFeature].getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            node.setContinuousSplit(bestBin);
        } else {
            node.setCategoricalSplit(bestCategorySides);
        }
        node.setGain(bestGain);
    }
}
