package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile;
import com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster;
import com.alibaba.alink.operator.common.tree.parallelcart.data.DataUtil;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils;
import com.alibaba.alink.params.regression.GbdtRegTrainParams;
import java.util.Arrays;
import java.util.BitSet;
import java.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/ConstructLocalHistogram.class */
/**
 * Compute step that builds this worker's local split histograms for every node currently in the
 * tree's work queue.
 *
 * <p>On step 1 it allocates reusable buffers, including the {@code "histogram"} and
 * {@code "recvcnts"} objects it stores in the {@link ComContext}. On every step it:
 * <ol>
 *   <li>samples bagging features for each queued node;</li>
 *   <li>builds the histogram over the rows belonging to those nodes;</li>
 *   <li>packs the per-node/per-feature slices densely into {@code "histogram"};</li>
 *   <li>records in {@code "recvcnts"} how many of those doubles fall into each task's contiguous
 *       share of the sampled features.</li>
 * </ol>
 * (NOTE(review): "recvcnts" presumably feeds a reduce-scatter in a later step; verify against the
 * consumer.)
 */
public class ConstructLocalHistogram extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(ConstructLocalHistogram.class);
    private static final long serialVersionUID = -325487480296758683L;

    /** Number of doubles stored per histogram bin. */
    private static final int STEP = 4;

    /** Features sampled by at least one queued node in the current step. */
    private BitSet featureValid;
    /** Ids of rows that belong to some queued node, compacted to the front (non-epsilon path). */
    private int[] aligned;
    /** For each feature, the total bin count of all valid features with a smaller index. */
    private int[] validFeatureOffset;
    /** Histogram of all valid features for all queued nodes, {@code STEP} doubles per bin. */
    private double[] featureSplitHistogram;
    /** Scratch futures handed to the histogram construction routines. */
    private Future<?>[] results;
    /** Cached {@link LossUtils#useInstanceCount} result for the configured loss. */
    private boolean useInstanceCount;
    /** Splits the sampled features into per-task ranges. */
    private final DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();

    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        LOG.info("taskId: {}, {} start", comContext.getTaskId(), ConstructLocalHistogram.class.getSimpleName());
        BoostingObjs boostingObjs = (BoostingObjs) comContext.getObj(InitBoostingObjs.BOOSTING_OBJS);
        Booster booster = (Booster) comContext.getObj(Boosting.BOOSTER);
        HistogramBaseTreeObjs tree = (HistogramBaseTreeObjs) comContext.getObj(InitTreeObjs.TREE);

        if (comContext.getStepNo() == 1) {
            allocateBuffers(comContext, boostingObjs, tree);
        }

        // Initialize per-tree state once when a new weak learner starts; the flag is presumably
        // reset elsewhere when the tree is finished.
        if (!boostingObjs.inWeakLearner) {
            tree.initPerTree(boostingObjs, (EpsilonApproQuantile.WQSummary[]) comContext.getObj(BuildLocalSketch.SKETCH));
            boostingObjs.inWeakLearner = true;
        }

        int sumFeatureCount = calcWithNodeIdCache(comContext, boostingObjs, booster, tree, (double[]) comContext.getObj("histogram"));
        fillRecvCounts((int[]) comContext.getObj("recvcnts"), comContext.getNumTask(), sumFeatureCount, boostingObjs, tree);

        LOG.info("taskId: {}, {} end", comContext.getTaskId(), ConstructLocalHistogram.class.getSimpleName());
    }

    /**
     * Samples bagging features for every queued node and builds this worker's histogram over the
     * rows of those nodes. It then packs the histogram densely into {@code dArr}.
     *
     * <p>Packing order:
     * <ul>
     *   <li>pairs are taken in {@code histogramBaseTreeObjs.queue} order;</li>
     *   <li>within a pair, the small node comes first, then the big node if present;</li>
     *   <li>within a node, features follow the sampled order;</li>
     *   <li>each feature contributes {@code bins * STEP} doubles.</li>
     * </ul>
     *
     * @param comContext            context providing the sketch and the task id for logging
     * @param boostingObjs          boosting state (data, params, row indices, executor)
     * @param booster               source of gradients, hessians and weights
     * @param histogramBaseTreeObjs tree state; its {@code nodeIdCache} and each queued node's
     *                              {@code baggingFeatures} are overwritten
     * @param dArr                  destination for the packed histogram
     * @return {@code numBaggingFeatures} times the number of queued nodes
     */
    public int calcWithNodeIdCache(ComContext comContext, BoostingObjs boostingObjs, Booster booster, HistogramBaseTreeObjs histogramBaseTreeObjs, double[] dArr) {
        HistogramBaseTreeObjs tree = histogramBaseTreeObjs;

        featureValid.clear();
        Arrays.fill(tree.nodeIdCache, -1);
        tree.baggingFeaturePool.reset();

        // Sample features per node and tag each of the node's rows with the node's id.
        int nodeSize = 0;
        int sumFeatureCount = 0;
        for (NodeInfoPair pair : tree.queue) {
            pair.small.baggingFeatures = Bagging.sampleFeatures(boostingObjs, tree.baggingFeaturePool);
            claimNode(pair.small.baggingFeatures, pair.small.slice.start, pair.small.slice.end, nodeSize, boostingObjs, tree);
            nodeSize++;
            sumFeatureCount += boostingObjs.numBaggingFeatures;
            if (pair.big != null) {
                pair.big.baggingFeatures = Bagging.sampleFeatures(boostingObjs, tree.baggingFeaturePool);
                claimNode(pair.big.baggingFeatures, pair.big.slice.start, pair.big.slice.end, nodeSize, boostingObjs, tree);
                nodeSize++;
                sumFeatureCount += boostingObjs.numBaggingFeatures;
            }
        }

        // Both histogram back-ends need the same per-feature offsets.
        computeValidFeatureOffsets(boostingObjs, tree);

        LOG.info("taskId: {}, calcWithNodeIdCache start", comContext.getTaskId());
        if (((Boolean) boostingObjs.params.get(BaseGbdtTrainBatchOp.USE_EPSILON_APPRO_QUANTILE)).booleanValue()) {
            EpsilonApproQuantile.WQSummary[] sketch = (EpsilonApproQuantile.WQSummary[]) comContext.getObj(BuildLocalSketch.SKETCH);
            boostingObjs.data.constructHistogramWithWQSummary(
                useInstanceCount, nodeSize, featureValid, tree.nodeIdCache, validFeatureOffset,
                booster.getGradients(), booster.getHessions(), booster.getWeights(), sketch,
                boostingObjs.executorService, results, featureSplitHistogram);
        } else {
            int numActiveRows = collectActiveRows(boostingObjs, tree);
            boostingObjs.data.constructHistogram(
                useInstanceCount, nodeSize, numActiveRows, featureValid, tree.nodeIdCache, validFeatureOffset, aligned,
                booster.getGradients(), booster.getHessions(), booster.getWeights(),
                boostingObjs.executorService, results, featureSplitHistogram);
        }
        LOG.info("taskId: {}, calcWithNodeIdCache end", comContext.getTaskId());

        // Pack each node's slice of every sampled feature; node ids follow the same order as above.
        int dstBin = 0;
        int nodeId = 0;
        for (NodeInfoPair pair : tree.queue) {
            for (int[] features : baggingFeaturesOf(pair)) {
                dstBin = copyNodeHistogram(features, nodeId, nodeSize, boostingObjs, tree, dArr, dstBin);
                nodeId++;
            }
        }

        LOG.info("taskId: {}, sumFeatureCount: {}", comContext.getTaskId(), sumFeatureCount);
        return sumFeatureCount;
    }

    /** Allocates the per-worker buffers and the shared "histogram"/"recvcnts" context objects. */
    private void allocateBuffers(ComContext comContext, BoostingObjs boostingObjs, HistogramBaseTreeObjs tree) {
        LOG.info("maxDepth: {}, maxLeaves: {}", boostingObjs.params.get(GbdtRegTrainParams.MAX_DEPTH), boostingObjs.params.get(GbdtRegTrainParams.MAX_LEAVES));
        int featureSplitLength = tree.maxNodeSize * STEP * tree.allFeatureBins;
        // Packed size bound: numBaggingFeatures features per node, assuming each has at most
        // maxFeatureBins bins, capped by the full split-histogram size.
        int histogramLength = Math.min(tree.maxNodeSize * tree.maxFeatureBins * STEP * boostingObjs.numBaggingFeatures, featureSplitLength);

        comContext.putObj("histogram", new double[histogramLength]);
        comContext.putObj("recvcnts", new int[comContext.getNumTask()]);

        int numFeatures = boostingObjs.data.getN();
        featureSplitHistogram = new double[featureSplitLength];
        featureValid = new BitSet(numFeatures);
        aligned = new int[boostingObjs.data.getM()];
        validFeatureOffset = new int[numFeatures];
        results = new Future<?>[numFeatures];
        useInstanceCount = LossUtils.useInstanceCount((LossType) boostingObjs.params.get(LossUtils.LOSS_TYPE));
    }

    /** Marks {@code baggingFeatures} as valid and tags rows {@code indices[sliceStart, sliceEnd)} with {@code nodeId}. */
    private void claimNode(int[] baggingFeatures, int sliceStart, int sliceEnd, int nodeId, BoostingObjs boostingObjs, HistogramBaseTreeObjs tree) {
        for (int featureId : baggingFeatures) {
            featureValid.set(featureId);
        }
        for (int pos = sliceStart; pos < sliceEnd; pos++) {
            tree.nodeIdCache[boostingObjs.indices[pos]] = nodeId;
        }
    }

    /** Fills {@link #validFeatureOffset} with the prefix sum of bin counts over valid features. */
    private void computeValidFeatureOffsets(BoostingObjs boostingObjs, HistogramBaseTreeObjs tree) {
        FeatureMeta[] featureMetas = boostingObjs.data.getFeatureMetas();
        int offset = 0;
        for (int featureId = 0; featureId < boostingObjs.data.getN(); featureId++) {
            validFeatureOffset[featureId] = offset;
            if (featureValid.get(featureId)) {
                offset += DataUtil.getFeatureCategoricalSize(featureMetas[featureId], tree.useMissing);
            }
        }
    }

    /** Compacts the ids of rows with {@code nodeIdCache >= 0} into {@link #aligned}; returns how many. */
    private int collectActiveRows(BoostingObjs boostingObjs, HistogramBaseTreeObjs tree) {
        int count = 0;
        for (int row = 0; row < boostingObjs.data.getM(); row++) {
            if (tree.nodeIdCache[row] >= 0) {
                aligned[count++] = row;
            }
        }
        return count;
    }

    /**
     * Copies node {@code nodeId}'s slice of each feature in {@code baggingFeatures} from
     * {@link #featureSplitHistogram} into {@code dst}, starting at bin {@code dstBin}.
     *
     * @return the bin index just past the last copied bin
     */
    private int copyNodeHistogram(int[] baggingFeatures, int nodeId, int nodeSize, BoostingObjs boostingObjs, HistogramBaseTreeObjs tree, double[] dst, int dstBin) {
        for (int featureId : baggingFeatures) {
            int bins = featureBins(boostingObjs, tree, featureId);
            // As read here: feature f's block starts at validFeatureOffset[f] * nodeSize bins, and
            // node j's slice begins j * bins into that block.
            int srcBin = validFeatureOffset[featureId] * nodeSize + nodeId * bins;
            System.arraycopy(featureSplitHistogram, srcBin * STEP, dst, dstBin * STEP, bins * STEP);
            dstBin += bins;
        }
        return dstBin;
    }

    /**
     * Fills {@code recvcnts} with, per task, the number of packed doubles ({@code bins * STEP})
     * of the features assigned to that task.
     *
     * <p>Features are numbered 1..n in packing order. Feature k goes to the first task whose
     * range end over {@code sumFeatureCount} items is at least k.
     */
    private void fillRecvCounts(int[] recvcnts, int numTask, int sumFeatureCount, BoostingObjs boostingObjs, HistogramBaseTreeObjs tree) {
        Arrays.fill(recvcnts, 0);
        int taskIdx = 0;
        int featureCnt = 0;
        int taskEnd = taskRangeEnd(taskIdx, numTask, sumFeatureCount);
        for (NodeInfoPair pair : tree.queue) {
            for (int[] features : baggingFeaturesOf(pair)) {
                for (int featureId : features) {
                    featureCnt++;
                    while (featureCnt > taskEnd) {
                        taskIdx++;
                        taskEnd = taskRangeEnd(taskIdx, numTask, sumFeatureCount);
                    }
                    recvcnts[taskIdx] += featureBins(boostingObjs, tree, featureId) * STEP;
                }
            }
        }
    }

    /** Returns {@code startPos + localRowCnt} for task {@code taskIdx}, i.e. one past its last item. */
    private int taskRangeEnd(int taskIdx, int numTask, int total) {
        return (int) (distributedInfo.startPos(taskIdx, numTask, total) + distributedInfo.localRowCnt(taskIdx, numTask, total));
    }

    /** Returns the bagging features of the pair's small node and, when present, of its big node. */
    private static int[][] baggingFeaturesOf(NodeInfoPair pair) {
        if (pair.big == null) {
            return new int[][] {pair.small.baggingFeatures};
        }
        return new int[][] {pair.small.baggingFeatures, pair.big.baggingFeatures};
    }

    /** Number of histogram bins of {@code featureId}, per {@link DataUtil#getFeatureCategoricalSize}. */
    private static int featureBins(BoostingObjs boostingObjs, HistogramBaseTreeObjs tree, int featureId) {
        return DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[featureId], tree.useMissing);
    }
}
