package com.alibaba.alink.operator.common.tree.seriestree;

import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.FeatureSplitter;
import com.alibaba.alink.operator.common.tree.LabelAccessor;
import com.alibaba.alink.operator.common.tree.Node;
import com.alibaba.alink.params.shared.tree.HasMaxDepth;
import com.alibaba.alink.params.shared.tree.HasMinSamplesPerLeaf;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/seriestree/CategoricalSplitter.class */
/**
 * Feature splitter for a categorical feature.
 *
 * <p>Two strategies are used, selected by the configured gain:
 * <ul>
 *   <li>GINI / MSE: a binary (CART-style) split that routes every category to a left (0) or a
 *       right (1) child. For GINI with a two-class label, categories are sorted by their weighted
 *       proportion of the first label and only the prefixes of that order are evaluated; otherwise
 *       every partition is enumerated.</li>
 *   <li>Any other gain: a multi-way split with one child per non-empty category.</li>
 * </ul>
 *
 * <p>In {@link #splitPoint}, entry {@code c} is the child index category {@code c} is routed to,
 * or {@code -1} for a category that has no rows at the current node.
 */
public class CategoricalSplitter extends SequentialFeatureSplitter {
    /** Child index for each category value; -1 for categories without rows at this node. */
    private int[] splitPoint;
    /** Label statistics of each child of the best split found so far. */
    private Criteria[] best;
    /** Label statistics of the non-missing rows at this node, one entry per category value. */
    private Criteria[] categoricalCriteria;

    public CategoricalSplitter(Params params, DenseData denseData, FeatureMeta featureMeta, SequentialPartition sequentialPartition) {
        super(params, denseData, featureMeta, sequentialPartition);
    }

    /**
     * Evaluates the best split of this feature at the current node.
     *
     * @param i compared against {@code maxLeaves} together with the number of children the split
     *          would create (presumably the current leaf count of the tree — TODO confirm)
     * @return the gain of the best admissible split, or {@link Criteria#INVALID_GAIN} if the node
     *         must not be split on this feature
     */
    @Override
    public double bestSplit(int i) {
        int maxDepth = (Integer) this.params.get(HasMaxDepth.MAX_DEPTH);
        int minSamplesPerLeafParam = (Integer) this.params.get(HasMinSamplesPerLeaf.MIN_SAMPLES_PER_LEAF);
        if (this.depth >= maxDepth || this.partition.dataIndices.size() <= minSamplesPerLeafParam) {
            return Criteria.INVALID_GAIN;
        }
        boolean binarySplit = this.params.get(Criteria.Gain.GAIN) == Criteria.Gain.GINI
            || this.params.get(Criteria.Gain.GAIN) == Criteria.Gain.MSE;
        return binarySplit ? bestSplitCart(i) : bestSplitInfo(i);
    }

    /**
     * Partitions the node's rows according to the best split found by {@link #bestSplit(int)}.
     *
     * @throws IllegalStateException if no admissible split was found
     */
    @Override
    public SequentialFeatureSplitter[][] split(FeatureSplitter[] featureSplitterArr) {
        if (!this.canSplit) {
            throw new IllegalStateException("The feature splitter should be calculated by `bestSplit`");
        }
        // Share of the non-missing weight that each child received.
        // NOTE(review): presumably used by splitCategorical to distribute missing rows — confirm.
        double[] childWeightRatios = new double[this.best.length];
        for (int child = 0; child < this.best.length; child++) {
            childWeightRatios[child] = this.best[child].getWeightSum() / this.total.getWeightSum();
        }
        int[] featureValues = (int[]) this.data.getFeatureValues(this.featureMeta.getIndex());
        return split(
            featureSplitterArr,
            this.partition.splitCategorical(featureValues, this.best.length, this.splitPoint, childWeightRatios));
    }

    /**
     * Accumulates per-category and missing-value label statistics for the rows of this node.
     * Runs only once per splitter; {@code total} is the sum over all non-missing categories.
     */
    @Override
    protected void count() {
        if (this.counted) {
            return;
        }
        int numCategorical = this.featureMeta.getNumCategorical();
        this.categoricalCriteria = new Criteria[numCategorical];
        LabelAccessor[] categoryAccessors = new LabelAccessor[numCategorical];
        for (int c = 0; c < numCategorical; c++) {
            this.categoricalCriteria[c] = criteriaOf();
            categoryAccessors[c] = labelAccessorOf(this.categoricalCriteria[c]);
        }
        this.missing = criteriaOf();
        LabelAccessor missingAccessor = labelAccessorOf(this.missing);

        int size = this.partition.dataIndices.size();
        int[] featureValues = (int[]) this.data.getFeatureValues(this.featureMeta.getIndex());
        for (int pos = 0; pos < size; pos++) {
            int rowIndex = (Integer) this.partition.dataIndices.get(pos).f0;
            int category = featureValues[rowIndex];
            if (DenseData.isCategoricalMissValue(category)) {
                missingAccessor.add(pos);
            } else {
                categoryAccessors[category].add(pos);
            }
        }

        this.total = criteriaOf();
        for (int c = 0; c < numCategorical; c++) {
            this.total.add(this.categoricalCriteria[c]);
        }
        this.counted = true;
    }

    @Override
    protected void fillNodeSplitPoint(Node node) {
        node.setCategoricalSplit(this.splitPoint);
    }

    /**
     * Checks the minimum-size constraints for one child of a candidate split. Rows with a missing
     * feature value are added to the child count and to the total before the checks.
     *
     * @param childNumInstances number of non-missing rows routed to the child
     */
    private boolean isValidChildSize(int childNumInstances) {
        int childWithMissing = childNumInstances + this.missing.getNumInstances();
        int totalWithMissing = this.total.getNumInstances() + this.missing.getNumInstances();
        // Floating-point division: an int/int ratio would truncate to 0 or 1.
        return this.minSamplesPerLeaf <= childWithMissing
            && this.minSampleRatioPerChild <= (double) childWithMissing / totalWithMissing;
    }

    /** Returns whether a category is skipped by the multi-way split (no rows or zero weight). */
    private boolean isEmptyCategory(int category) {
        Criteria criteria = this.categoricalCriteria[category];
        // NOTE(review): INVALID_GAIN is used here as the zero-weight sentinel; presumably 0.0 — confirm.
        return criteria.getNumInstances() == 0 || criteria.getWeightSum() == Criteria.INVALID_GAIN;
    }

    /**
     * Binary GINI split for a two-class label. Non-empty categories are sorted by their weighted
     * proportion of the first label and moved one at a time, in that order, from the right child
     * to the left child. Each resulting prefix partition is evaluated (the standard CART shortcut
     * for binary labels).
     */
    private double bestSplitCartIncrement() {
        int numCategorical = this.featureMeta.getNumCategorical();
        List<Tuple2<Integer, Double>> order = new ArrayList<>(numCategorical);
        int[] assignment = new int[numCategorical];
        for (int c = 0; c < numCategorical; c++) {
            if (this.categoricalCriteria[c].getNumInstances() == 0) {
                // Empty categories are never routed, consistent with bestSplitCart/bestSplitInfo.
                assignment[c] = -1;
                continue;
            }
            double firstLabelRatio = this.categoricalCriteria[c].mo606clone()
                .toLabelCounter().normWithWeight().getDistributions()[0];
            order.add(Tuple2.of(c, firstLabelRatio));
            assignment[c] = 1;
        }
        order.sort(Comparator.comparing((Tuple2<Integer, Double> entry) -> entry.f1));

        Criteria.Gini left = (Criteria.Gini) criteriaOf();
        Criteria.Gini right = (Criteria.Gini) this.total.mo606clone();
        this.bestGain = Criteria.INVALID_GAIN;
        this.splitPoint = new int[numCategorical];
        this.best = new Criteria[2];

        // Walk the categories in sorted order; previously the sorted list was ignored.
        for (Tuple2<Integer, Double> entry : order) {
            int c = entry.f0;
            assignment[c] = 0;
            left.add(this.categoricalCriteria[c]);
            right.subtract(this.categoricalCriteria[c]);
            if (left.getNumInstances() != 0 && right.getNumInstances() != 0
                && isValidChildSize(left.getNumInstances())
                && isValidChildSize(right.getNumInstances())) {
                double gain = this.total.gain(left, right);
                if (gain > this.bestGain && gain >= this.minInfoGain) {
                    this.bestGain = gain;
                    this.canSplit = true;
                    this.splitPoint = assignment.clone();
                    this.best[0] = left.mo606clone();
                    this.best[1] = right.mo606clone();
                }
            }
        }
        return this.bestGain;
    }

    /**
     * Exhaustive binary split: enumerates every bitmask over the categories, where a set bit sends
     * the category to the left child. The last category's bit is never set, so each unordered
     * partition is visited once.
     */
    private double bestSplitCart() {
        int numCategorical = this.featureMeta.getNumCategorical();
        this.bestGain = Criteria.INVALID_GAIN;
        this.splitPoint = new int[numCategorical];
        this.best = new Criteria[2];
        int[] assignment = new int[numCategorical];
        // NOTE(review): int masks — for numCategorical >= 32 the shifts wrap, so the enumeration is
        // empty or incomplete. Exhaustive search is infeasible at that size anyway.
        int numMasks = 1 << (numCategorical - 1);
        Criteria left = criteriaOf();
        Criteria right = criteriaOf();
        for (int mask = 1; mask < numMasks; mask++) {
            left.reset();
            right.reset();
            for (int c = 0; c < numCategorical; c++) {
                if (this.categoricalCriteria[c].getNumInstances() == 0) {
                    assignment[c] = -1;
                } else if ((mask & (1 << c)) != 0) {
                    assignment[c] = 0;
                    left.add(this.categoricalCriteria[c]);
                } else {
                    assignment[c] = 1;
                    right.add(this.categoricalCriteria[c]);
                }
            }
            if (left.getNumInstances() != 0 && right.getNumInstances() != 0
                && isValidChildSize(left.getNumInstances())
                && isValidChildSize(right.getNumInstances())) {
                double gain = this.total.gain(left, right);
                if (gain > this.bestGain && gain >= this.minInfoGain) {
                    this.bestGain = gain;
                    this.canSplit = true;
                    this.splitPoint = assignment.clone();
                    this.best[0] = left.mo606clone();
                    this.best[1] = right.mo606clone();
                }
            }
        }
        return this.bestGain;
    }

    /** Binary split entry point: checks the leaf budget, then picks the incremental or exhaustive search. */
    private double bestSplitCart(int i) {
        // Reject when i plus the two new children reaches maxLeaves.
        if (i + 2 >= this.maxLeaves) {
            return Criteria.INVALID_GAIN;
        }
        count();
        if (this.missing.getNumInstances() == this.partition.dataIndices.size()) {
            // Every row is missing this feature: nothing to split on.
            return Criteria.INVALID_GAIN;
        }
        boolean binaryGini = this.params.get(Criteria.Gain.GAIN) == Criteria.Gain.GINI
            && this.data.labelMeta.getNumCategorical() == 2;
        return binaryGini ? bestSplitCartIncrement() : bestSplitCart();
    }

    /**
     * Multi-way split with one child per non-empty category. Returns INVALID_GAIN if:
     * <ul>
     *   <li>any child violates the size constraints,</li>
     *   <li>fewer than two children exist, or</li>
     *   <li>the leaf budget would be exceeded.</li>
     * </ul>
     */
    private double bestSplitInfo(int i) {
        count();
        if (this.missing.getNumInstances() == this.partition.dataIndices.size()) {
            return Criteria.INVALID_GAIN;
        }
        int numCategorical = this.featureMeta.getNumCategorical();
        int numChildren = 0;
        for (int c = 0; c < numCategorical; c++) {
            if (isEmptyCategory(c)) {
                continue;
            }
            if (!isValidChildSize(this.categoricalCriteria[c].getNumInstances())) {
                return Criteria.INVALID_GAIN;
            }
            numChildren++;
        }
        if (numChildren < 2 || i + numChildren >= this.maxLeaves) {
            return Criteria.INVALID_GAIN;
        }

        this.splitPoint = new int[numCategorical];
        this.best = new Criteria[numChildren];
        int child = 0;
        for (int c = 0; c < numCategorical; c++) {
            if (isEmptyCategory(c)) {
                this.splitPoint[c] = -1;
            } else {
                this.splitPoint[c] = child;
                this.best[child] = this.categoricalCriteria[c];
                child++;
            }
        }
        this.bestGain = this.total.gain(this.best);
        if (this.bestGain <= Criteria.INVALID_GAIN || this.bestGain < this.minInfoGain) {
            return Criteria.INVALID_GAIN;
        }
        this.canSplit = true;
        return this.bestGain;
    }
}
