package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.operator.common.tree.parallelcart.NodeInfoPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/UpdatePredictionScore.class */
/**
 * Communication-queue step that folds the output of the tree just built into the running
 * prediction scores.
 *
 * <p>For every leaf of the current tree, the leaf value (the first entry of the leaf node's
 * counter distributions) is added to {@code BoostingObjs.pred} for each sample position listed in
 * the leaf's {@code slice} range and in its {@code oob} range.
 */
public final class UpdatePredictionScore extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(UpdatePredictionScore.class);
    private static final long serialVersionUID = -1747045532702426860L;

    /**
     * Adds each leaf's value to the prediction scores of the samples assigned to that leaf.
     *
     * <p>Does nothing while {@code boostingObjs.inWeakLearner} is set.
     *
     * @param comContext context holding the boosting state (under
     *     {@code InitBoostingObjs.BOOSTING_OBJS}) and the tree state (under
     *     {@code InitTreeObjs.TREE})
     */
    @Override
    public void calc(ComContext comContext) {
        BoostingObjs boostingObjs = (BoostingObjs) comContext.getObj(InitBoostingObjs.BOOSTING_OBJS);
        if (boostingObjs.inWeakLearner) {
            return;
        }

        final String stepName = UpdatePredictionScore.class.getSimpleName();
        LOG.info("taskId: {}, {} start", comContext.getTaskId(), stepName);

        HistogramBaseTreeObjs treeObjs = (HistogramBaseTreeObjs) comContext.getObj(InitTreeObjs.TREE);
        for (NodeInfoPair.NodeInfo leaf : treeObjs.leaves) {
            double leafValue = leaf.node.getCounter().getDistributions()[0];
            // In-sample rows of this leaf first, then its out-of-bag rows.
            addToPredictions(boostingObjs, leaf.slice.start, leaf.slice.end, leafValue);
            addToPredictions(boostingObjs, leaf.oob.start, leaf.oob.end, leafValue);
        }

        LOG.info("taskId: {}, {} end", comContext.getTaskId(), stepName);
    }

    /**
     * Adds {@code delta} to {@code pred[indices[pos]]} for every {@code pos} in {@code [from, to)}.
     *
     * @param boostingObjs boosting state whose {@code pred} array is updated in place
     * @param from first position (inclusive) into {@code indices}
     * @param to last position (exclusive) into {@code indices}
     * @param delta value added to each addressed prediction
     */
    private static void addToPredictions(BoostingObjs boostingObjs, int from, int to, double delta) {
        for (int pos = from; pos < to; ++pos) {
            boostingObjs.pred[boostingObjs.indices[pos]] += delta;
        }
    }
}
