package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.Node;
import com.alibaba.alink.operator.common.tree.parallelcart.data.DataUtil;
import com.alibaba.alink.operator.common.tree.parallelcart.data.Slice;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/CalcFeatureGain.class */
/**
 * Compute step that scores candidate splits for every node pair currently queued in the
 * histogram-based tree and records, per node slot, the best {@link Node} found among the
 * features assigned to this task.
 *
 * <p>All (node, bagging feature) pairs in the queue are enumerated in a fixed order: for each
 * {@link NodeInfoPair}, first the small child's features, then the big child's features if
 * present. That global sequence is partitioned across tasks with {@link DefaultDistributedInfo},
 * and each task only scores the pairs inside its own range.
 *
 * <p>Results are published to the context under two keys. {@code "best"} holds a {@code Node[]}
 * that is created on step 1 and reused afterwards. {@code "bestLength"} holds the number of node
 * slots written in this step. An entry may be {@code null} when this task scored no feature for
 * that node.
 */
public final class CalcFeatureGain extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(CalcFeatureGain.class);
    private static final long serialVersionUID = 3072216204653127272L;

    /** One splitter per feature, indexed by feature id; built on the first step. */
    private HistogramFeatureSplitter[] featureSplitters;

    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        final int taskId = comContext.getTaskId();
        LOG.info("taskId: {}, {} start", taskId, CalcFeatureGain.class.getSimpleName());

        final BoostingObjs boostingObjs = (BoostingObjs) comContext.getObj(InitBoostingObjs.BOOSTING_OBJS);
        final HistogramBaseTreeObjs tree = (HistogramBaseTreeObjs) comContext.getObj(InitTreeObjs.TREE);
        final double[] histogram = (double[]) comContext.getObj("histogram");
        final FeatureMeta[] featureMetas = boostingObjs.data.getFeatureMetas();

        // First step only: allocate the shared result array and one splitter per feature.
        if (comContext.getStepNo() == 1) {
            comContext.putObj("best", new Node[tree.maxNodeSize]);
            final int numFeatures = boostingObjs.data.getN();
            featureSplitters = new HistogramFeatureSplitter[numFeatures];
            for (int featureId = 0; featureId < numFeatures; featureId++) {
                final FeatureMeta meta = featureMetas[featureId];
                final boolean categorical = meta.getType() == FeatureMeta.FeatureType.CATEGORICAL;
                featureSplitters[featureId] =
                    createFeatureSplitter(categorical, boostingObjs.params, meta, tree.compareIndex4Categorical);
            }
        }

        // Count every (node, bagging feature) pair in the queue; a pair with a big child contributes twice.
        int totalPairs = 0;
        for (NodeInfoPair pair : tree.queue) {
            totalPairs += (pair.big == null ? 1 : 2) * boostingObjs.numBaggingFeatures;
        }

        // Half-open range [rangeStart, rangeEnd) of global pair indices owned by this task.
        final DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
        final int numTask = comContext.getNumTask();
        final int rangeStart = (int) distributedInfo.startPos(taskId, numTask, totalPairs);
        final int rangeEnd = rangeStart + (int) distributedInfo.localRowCnt(taskId, numTask, totalPairs);

        final Node[] best = (Node[]) comContext.getObj("best");
        int pairCursor = 0; // global index of the next (node, feature) pair
        int binOffset = 0;  // start of the next owned feature's bins in the local histogram
        int slot = 0;       // index into `best` for the node currently being scored

        for (NodeInfoPair pair : tree.queue) {
            best[slot] = null;
            for (int featureId : pair.small.baggingFeatures) {
                final int pairIndex = pairCursor++;
                if (pairIndex < rangeStart || pairIndex >= rangeEnd) {
                    continue;
                }
                final int binSize = DataUtil.getFeatureCategoricalSize(featureMetas[featureId], tree.useMissing);
                final HistogramFeatureSplitter splitter = featureSplitters[featureId];
                splitter.reset(histogram, new Slice(binOffset, binOffset + binSize), pair.small.depth);
                final double gain = splitter.bestSplit(tree.leaves.size());
                // The first owned feature always seeds the slot; later ones replace it only if splittable and strictly better.
                if (best[slot] == null || (splitter.canSplit() && gain > best[slot].getGain())) {
                    best[slot] = new Node();
                    splitter.fillNode(best[slot]);
                }
                binOffset += binSize;
            }
            slot++;

            if (pair.big == null) {
                continue;
            }

            best[slot] = null;
            for (int featureId : pair.big.baggingFeatures) {
                final int pairIndex = pairCursor++;
                if (pairIndex < rangeStart || pairIndex >= rangeEnd) {
                    continue;
                }
                final int binSize = DataUtil.getFeatureCategoricalSize(featureMetas[featureId], tree.useMissing);
                final HistogramFeatureSplitter splitter = featureSplitters[featureId];
                splitter.reset(histogram, new Slice(binOffset, binOffset + binSize), pair.big.depth);
                final double gain = splitter.bestSplit(tree.leaves.size());
                if (best[slot] == null || (splitter.canSplit() && gain > best[slot].getGain())) {
                    best[slot] = new Node();
                    splitter.fillNode(best[slot]);
                }
                binOffset += binSize;
            }
            slot++;
        }

        comContext.putObj("bestLength", slot);
        LOG.info("taskId: {}, {} end", taskId, CalcFeatureGain.class.getSimpleName());
    }

    /**
     * Builds the histogram splitter appropriate for one feature.
     *
     * @param categorical whether the feature is categorical
     * @param params training parameters forwarded to the splitter
     * @param featureMeta metadata of the feature
     * @param compareIndex4Categorical forwarded only to the categorical splitter
     * @return a categorical splitter when {@code categorical} is true, otherwise a continuous one
     */
    private HistogramFeatureSplitter createFeatureSplitter(
            boolean categorical, Params params, FeatureMeta featureMeta, Integer[] compareIndex4Categorical) {
        if (categorical) {
            return new HistogramCategoricalFeatureSplitter(params, featureMeta, compareIndex4Categorical);
        }
        return new HistogramContinuousFeatureSplitter(params, featureMeta);
    }
}
