package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.FeatureSplitter;
import com.alibaba.alink.operator.common.tree.LabelAccessor;
import com.alibaba.alink.operator.common.tree.parallelcart.criteria.AlinkCriteria;
import com.alibaba.alink.operator.common.tree.parallelcart.criteria.CriteriaType;
import com.alibaba.alink.operator.common.tree.parallelcart.criteria.GBMTreeSplitCriteria;
import com.alibaba.alink.operator.common.tree.parallelcart.criteria.PaiCriteria;
import com.alibaba.alink.operator.common.tree.parallelcart.criteria.XGboostCriteria;
import com.alibaba.alink.operator.common.tree.parallelcart.data.Slice;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import java.util.Arrays;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/HistogramFeatureSplitter.class */
public abstract class HistogramFeatureSplitter extends FeatureSplitter {
    // Flat histogram buffer installed by reset(); the accessors in this class read
    // four consecutive doubles per bin (their `step` constant is 4).
    protected double[] featureHist;
    // Bin range installed by reset(); the accessors report (end - start) bins and
    // offset every bin index by `start`.
    protected Slice slice;
    // NOTE(review): not read or written in this class; presumably the criteria of the
    // best left/right children, maintained by subclasses — confirm.
    protected Criteria bestLeft;
    protected Criteria bestRight;
    // Value of GbdtTrainParams.MAX_DEPTH, read once in the constructor.
    protected final int maxDepth;
    // Result of LossUtils.useInstanceCount(...) for the configured loss type.
    protected final boolean useInstanceCnt;
    // Value of BaseGbdtTrainBatchOp.USE_MISSING; when true, count() accumulates the
    // missing-value bin separately and subtracts it from the total.
    protected final boolean useMissing;
    // NOTE(review): never assigned in this class; presumably filled by subclasses
    // when they decide where missing values go — confirm.
    protected int[] missingSplit;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/HistogramFeatureSplitter$AlinkCriteriaHistogramAccessor.class */
    /**
     * Label accessor that feeds histogram bins of the enclosing splitter into an
     * {@link AlinkCriteria}.
     *
     * <p>Each bin occupies {@code step} consecutive doubles in {@code featureHist}; bin indices
     * are relative to {@code slice.start}. The first three values are handed to the criteria as
     * doubles and the fourth is truncated to an int.
     * NOTE(review): the four slots are presumably gradient sum, hessian sum, weight sum and
     * instance count — confirm against {@code AlinkCriteria.add}.
     */
    public class AlinkCriteriaHistogramAccessor extends LabelAccessor implements CategoricalLabelSortable {
        static final int step = 4;
        AlinkCriteria criteria;

        AlinkCriteriaHistogramAccessor(AlinkCriteria alinkCriteria) {
            this.criteria = alinkCriteria;
        }

        /** Returns the position in {@code featureHist} of the first value stored for {@code bin}. */
        private int offsetOf(int bin) {
            return step * (HistogramFeatureSplitter.this.slice.start + bin);
        }

        /** Returns slot 0 divided by slot 1 for {@code bin}; this is the categorical sort key. */
        private double ratioAt(int bin) {
            final int offset = offsetOf(bin);
            final double[] hist = HistogramFeatureSplitter.this.featureHist;
            return hist[offset] / hist[offset + 1];
        }

        /**
         * Sorts {@code numArr[i, i2)} in place, ascending by the slot-0 / slot-1 ratio of the bin
         * each element refers to.
         */
        @Override // com.alibaba.alink.operator.common.tree.parallelcart.HistogramFeatureSplitter.CategoricalLabelSortable
        public void sort4Categorical(Integer[] numArr, int i, int i2) {
            Arrays.sort(numArr, i, i2, (left, right) -> {
                final int leftBin = left.intValue();
                final int rightBin = right.intValue();
                return Double.compare(ratioAt(leftBin), ratioAt(rightBin));
            });
        }

        /** Returns the number of bins covered by the current slice. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public int size() {
            final Slice range = HistogramFeatureSplitter.this.slice;
            return range.end - range.start;
        }

        /** Adds the four histogram values of bin {@code i} to the wrapped criteria. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public void add(int i) {
            final int offset = offsetOf(i);
            final AlinkCriteria target = this.criteria;
            final double[] hist = HistogramFeatureSplitter.this.featureHist;
            target.add(hist[offset], hist[offset + 1], hist[offset + 2], (int) hist[offset + 3]);
        }

        /** Subtracts the four histogram values of bin {@code i} from the wrapped criteria. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public void sub(int i) {
            final int offset = offsetOf(i);
            final AlinkCriteria target = this.criteria;
            final double[] hist = HistogramFeatureSplitter.this.featureHist;
            target.subtract(hist[offset], hist[offset + 1], hist[offset + 2], (int) hist[offset + 3]);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/HistogramFeatureSplitter$CategoricalLabelSortable.class */
    /**
     * Implemented by the histogram accessors in this file so that categorical bin indices can
     * be ordered by a criteria-specific key taken from the histogram.
     */
    interface CategoricalLabelSortable {
        /**
         * Sorts part of {@code numArr} in place.
         *
         * <p>The implementations in this file forward {@code i} and {@code i2} to
         * {@code Arrays.sort} as its from-index and to-index, so the sorted range is
         * {@code [i, i2)}.
         *
         * @param numArr bin indices to reorder
         * @param i first position of the range to sort (inclusive)
         * @param i2 end position of the range to sort (exclusive)
         */
        void sort4Categorical(Integer[] numArr, int i, int i2);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/HistogramFeatureSplitter$PaiCriteriaHistogramAccessor.class */
    /**
     * Label accessor that feeds histogram bins of the enclosing splitter into a
     * {@link PaiCriteria}.
     *
     * <p>Each bin occupies {@code step} consecutive doubles in {@code featureHist}; bin indices
     * are relative to {@code slice.start}. The first three values are handed to the criteria as
     * doubles and the fourth is truncated to an int.
     * NOTE(review): the four slots are presumably gradient sum, hessian sum, weight sum and
     * instance count — confirm against {@code PaiCriteria.add}.
     */
    public class PaiCriteriaHistogramAccessor extends LabelAccessor implements CategoricalLabelSortable {
        static final int step = 4;
        PaiCriteria criteria;

        PaiCriteriaHistogramAccessor(PaiCriteria paiCriteria) {
            this.criteria = paiCriteria;
        }

        /** Returns the position in {@code featureHist} of the first value stored for {@code bin}. */
        private int offsetOf(int bin) {
            return step * (HistogramFeatureSplitter.this.slice.start + bin);
        }

        /** Returns slot 0 of {@code bin}; this is the categorical sort key for this accessor. */
        private double firstSlotAt(int bin) {
            return HistogramFeatureSplitter.this.featureHist[offsetOf(bin)];
        }

        /**
         * Sorts {@code numArr[i, i2)} in place, ascending by the first histogram value of the bin
         * each element refers to (no ratio is taken here).
         */
        @Override // com.alibaba.alink.operator.common.tree.parallelcart.HistogramFeatureSplitter.CategoricalLabelSortable
        public void sort4Categorical(Integer[] numArr, int i, int i2) {
            Arrays.sort(
                numArr,
                i,
                i2,
                (left, right) -> Double.compare(firstSlotAt(left.intValue()), firstSlotAt(right.intValue())));
        }

        /** Returns the number of bins covered by the current slice. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public int size() {
            final Slice range = HistogramFeatureSplitter.this.slice;
            return range.end - range.start;
        }

        /** Adds the four histogram values of bin {@code i} to the wrapped criteria. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public void add(int i) {
            final int offset = offsetOf(i);
            final PaiCriteria target = this.criteria;
            final double[] hist = HistogramFeatureSplitter.this.featureHist;
            target.add(hist[offset], hist[offset + 1], hist[offset + 2], (int) hist[offset + 3]);
        }

        /** Subtracts the four histogram values of bin {@code i} from the wrapped criteria. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public void sub(int i) {
            final int offset = offsetOf(i);
            final PaiCriteria target = this.criteria;
            final double[] hist = HistogramFeatureSplitter.this.featureHist;
            target.subtract(hist[offset], hist[offset + 1], hist[offset + 2], (int) hist[offset + 3]);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/HistogramFeatureSplitter$XGBoostCriteriaHistogramAccessor.class */
    /**
     * Label accessor that feeds histogram bins of the enclosing splitter into an
     * {@link XGboostCriteria}.
     *
     * <p>Each bin occupies {@code step} consecutive doubles in {@code featureHist}; bin indices
     * are relative to {@code slice.start}. The first three values are handed to the criteria as
     * doubles and the fourth is truncated to an int.
     * NOTE(review): the four slots are presumably gradient sum, hessian sum, weight sum and
     * instance count — confirm against {@code XGboostCriteria.add}.
     */
    public class XGBoostCriteriaHistogramAccessor extends LabelAccessor implements CategoricalLabelSortable {
        static final int step = 4;
        XGboostCriteria criteria;

        XGBoostCriteriaHistogramAccessor(XGboostCriteria xGboostCriteria) {
            this.criteria = xGboostCriteria;
        }

        /** Returns the position in {@code featureHist} of the first value stored for {@code bin}. */
        private int offsetOf(int bin) {
            return step * (HistogramFeatureSplitter.this.slice.start + bin);
        }

        /** Returns slot 0 divided by slot 1 for {@code bin}; this is the categorical sort key. */
        private double ratioAt(int bin) {
            final int offset = offsetOf(bin);
            final double[] hist = HistogramFeatureSplitter.this.featureHist;
            return hist[offset] / hist[offset + 1];
        }

        /**
         * Sorts {@code numArr[i, i2)} in place, ascending by the slot-0 / slot-1 ratio of the bin
         * each element refers to.
         */
        @Override // com.alibaba.alink.operator.common.tree.parallelcart.HistogramFeatureSplitter.CategoricalLabelSortable
        public void sort4Categorical(Integer[] numArr, int i, int i2) {
            Arrays.sort(numArr, i, i2, (left, right) -> {
                final int leftBin = left.intValue();
                final int rightBin = right.intValue();
                return Double.compare(ratioAt(leftBin), ratioAt(rightBin));
            });
        }

        /** Returns the number of bins covered by the current slice. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public int size() {
            final Slice range = HistogramFeatureSplitter.this.slice;
            return range.end - range.start;
        }

        /** Adds the four histogram values of bin {@code i} to the wrapped criteria. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public void add(int i) {
            final int offset = offsetOf(i);
            final XGboostCriteria target = this.criteria;
            final double[] hist = HistogramFeatureSplitter.this.featureHist;
            target.add(hist[offset], hist[offset + 1], hist[offset + 2], (int) hist[offset + 3]);
        }

        /** Subtracts the four histogram values of bin {@code i} from the wrapped criteria. */
        @Override // com.alibaba.alink.operator.common.tree.LabelAccessor
        public void sub(int i) {
            final int offset = offsetOf(i);
            final XGboostCriteria target = this.criteria;
            final double[] hist = HistogramFeatureSplitter.this.featureHist;
            target.subtract(hist[offset], hist[offset + 1], hist[offset + 2], (int) hist[offset + 3]);
        }
    }

    public HistogramFeatureSplitter(Params params, FeatureMeta featureMeta) {
        super(params, featureMeta);
        this.maxDepth = ((Integer) params.get(GbdtTrainParams.MAX_DEPTH)).intValue();
        this.useInstanceCnt = LossUtils.useInstanceCount((LossType) params.get(LossUtils.LOSS_TYPE));
        this.useMissing = ((Boolean) params.get(BaseGbdtTrainBatchOp.USE_MISSING)).booleanValue();
    }

    /**
     * Points this splitter at a new histogram and clears its cached state.
     *
     * @param dArr histogram buffer to read bins from; stored as {@code featureHist}
     * @param slice bin range within the buffer; stored as {@code slice}
     * @param i value stored into the inherited {@code depth} field
     */
    public void reset(double[] dArr, Slice slice, int i) {
        // Drop the cached flags so that the next count() call recomputes from the new histogram.
        this.counted = false;
        this.canSplit = false;

        this.depth = i;
        this.slice = slice;
        this.featureHist = dArr;
    }

    /**
     * Not supported by this class.
     *
     * @param featureSplitterArr ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // com.alibaba.alink.operator.common.tree.FeatureSplitter
    public FeatureSplitter[][] split(FeatureSplitter[] featureSplitterArr) {
        final String message = "Unsupported.";
        throw new UnsupportedOperationException(message);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Accumulates the histogram into the inherited {@code total} and {@code missing} criteria.
     *
     * <p>Does nothing when {@code counted} is already set. Otherwise every bin of the current
     * slice is added to a fresh {@code total}. When {@code useMissing} is on, the bin at
     * {@code featureMeta.getMissingIndex()} is added to a fresh {@code missing} and then
     * subtracted from {@code total}. {@code counted} is set at the end.
     */
    @Override // com.alibaba.alink.operator.common.tree.FeatureSplitter
    public void count() {
        if (!this.counted) {
            this.total = criteriaOf();
            this.missing = criteriaOf();

            final LabelAccessor totalAccessor = labelAccessorOf(this.total);
            int bin = 0;
            while (bin < totalAccessor.size()) {
                totalAccessor.add(bin);
                bin++;
            }

            if (this.useMissing) {
                // The loop above already added the missing bin to total, so move it out.
                final LabelAccessor missingAccessor = labelAccessorOf(this.missing);
                missingAccessor.add(this.featureMeta.getMissingIndex());
                this.total.subtract(this.missing);
            }

            this.counted = true;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds a new split criteria of the kind selected by {@code CriteriaType.CRITERIA_TYPE}.
     *
     * <p>Every criteria is constructed with {@code Criteria.INVALID_GAIN} for its three trailing
     * double arguments and {@code 0} for its int argument. The XGBoost variant additionally
     * receives {@code GbdtTrainParams.LAMBDA} and {@code GbdtTrainParams.GAMMA}.
     *
     * @return a newly constructed criteria instance
     * @throws IllegalStateException if the configured criteria type is not PAI, ALINK or XGBOOST
     */
    public GBMTreeSplitCriteria criteriaOf() {
        final CriteriaType type = (CriteriaType) this.params.get(CriteriaType.CRITERIA_TYPE);
        final GBMTreeSplitCriteria fresh;
        switch (type) {
            case PAI:
                fresh = new PaiCriteria(Criteria.INVALID_GAIN, 0, Criteria.INVALID_GAIN, Criteria.INVALID_GAIN);
                break;
            case ALINK:
                fresh = new AlinkCriteria(Criteria.INVALID_GAIN, 0, Criteria.INVALID_GAIN, Criteria.INVALID_GAIN);
                break;
            case XGBOOST: {
                final double lambda = ((Double) this.params.get(GbdtTrainParams.LAMBDA)).doubleValue();
                final double gamma = ((Double) this.params.get(GbdtTrainParams.GAMMA)).doubleValue();
                fresh = new XGboostCriteria(
                    lambda, gamma, Criteria.INVALID_GAIN, 0, Criteria.INVALID_GAIN, Criteria.INVALID_GAIN);
                break;
            }
            default:
                throw new IllegalStateException("There should be set the gain type");
        }
        return fresh;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Wraps a criteria in the histogram accessor that matches its concrete type.
     *
     * @param criteria criteria the returned accessor will add to and subtract from
     * @return an accessor bound to this splitter's current histogram and slice
     * @throws IllegalStateException if {@code criteria} is not a {@link PaiCriteria},
     *     {@link AlinkCriteria} or {@link XGboostCriteria}
     */
    public LabelAccessor labelAccessorOf(Criteria criteria) {
        final LabelAccessor accessor;
        if (criteria instanceof PaiCriteria) {
            accessor = new PaiCriteriaHistogramAccessor((PaiCriteria) criteria);
        } else if (criteria instanceof AlinkCriteria) {
            accessor = new AlinkCriteriaHistogramAccessor((AlinkCriteria) criteria);
        } else if (criteria instanceof XGboostCriteria) {
            accessor = new XGBoostCriteriaHistogramAccessor((XGboostCriteria) criteria);
        } else {
            throw new IllegalStateException("The criteria type must be pai, alink or xgboost.");
        }
        return accessor;
    }
}
