package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.operator.common.tree.parallelcart.NodeInfoPair;
import com.alibaba.alink.operator.common.tree.parallelcart.leafscoreupdater.LeafScoreUpdater;
import com.alibaba.alink.operator.common.tree.parallelcart.leafscoreupdater.LeafScoreUpdaterFactory;
import com.alibaba.alink.operator.common.tree.parallelcart.leafscoreupdater.LeafScoreUpdaterType;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/UpdateLeafScore.class */
/**
 * Compute step that hands every leaf of the current tree to a {@link LeafScoreUpdater}.
 *
 * <p>On step 1 the updater is built from the boosting parameters and cached in the
 * {@link ComContext}. On any call where the boosting state does not report
 * {@code inWeakLearner}, the cached updater is applied to each leaf in
 * {@code HistogramBaseTreeObjs.leaves}. The inputs are the leaf's weight sum, its
 * distributions and the configured learning rate.
 */
public final class UpdateLeafScore extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(UpdateLeafScore.class);
    private static final long serialVersionUID = -1147524895767655755L;

    /** Context key under which the leaf score updater is cached between steps. */
    private static final String UPDATER_KEY = "updater";

    /**
     * Creates the updater on the first step, then updates leaf scores unless the boosting
     * state is flagged as being inside the weak learner.
     *
     * @param comContext shared context holding the boosting and tree objects
     */
    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        final BoostingObjs boosting = (BoostingObjs) comContext.getObj(InitBoostingObjs.BOOSTING_OBJS);

        // Build the updater once; later steps read it back from the context.
        if (comContext.getStepNo() == 1) {
            final LeafScoreUpdaterType updaterType =
                (LeafScoreUpdaterType) boosting.params.get(LeafScoreUpdaterType.LEAF_SCORE_UPDATER_TYPE);
            comContext.putObj(UPDATER_KEY, LeafScoreUpdaterFactory.createLeafScoreUpdater(updaterType));
        }

        // Nothing further happens while the boosting state reports it is in the weak learner.
        if (!boosting.inWeakLearner) {
            updateLeaves(comContext, boosting);
        }
    }

    /**
     * Applies the cached updater to every leaf of the current tree.
     *
     * @param comContext shared context holding the tree objects and the cached updater
     * @param boosting   boosting state whose parameters supply the learning rate
     */
    private static void updateLeaves(ComContext comContext, BoostingObjs boosting) {
        final String stepName = UpdateLeafScore.class.getSimpleName();
        LOG.info("taskId: {}, {} start", comContext.getTaskId(), stepName);

        final HistogramBaseTreeObjs treeObjs = (HistogramBaseTreeObjs) comContext.getObj(InitTreeObjs.TREE);
        final LeafScoreUpdater updater = (LeafScoreUpdater) comContext.getObj(UPDATER_KEY);
        final double learningRate = (Double) boosting.params.get(GbdtTrainParams.LEARNING_RATE);

        for (NodeInfoPair.NodeInfo leaf : treeObjs.leaves) {
            updater.update(
                leaf.node.getCounter().getWeightSum(),
                leaf.node.getCounter().getDistributions(),
                learningRate
            );
        }

        LOG.info("taskId: {}, {} end", comContext.getTaskId(), stepName);
    }
}
