package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster;
import com.alibaba.alink.operator.common.tree.parallelcart.booster.BoosterFactory;
import com.alibaba.alink.operator.common.tree.parallelcart.booster.BoosterType;
import com.alibaba.alink.operator.common.tree.parallelcart.data.Slice;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/Boosting.class */
/**
 * Iterative compute step that runs one boosting round of the parallel CART trainer.
 *
 * <p>Each invocation reads the shared {@link BoostingObjs} from the {@link ComContext}. If
 * {@code boostingObjs.inWeakLearner} is set, the step does nothing. Otherwise:
 * <ul>
 *   <li>on superstep 1 it creates a {@link Booster} and stores it in the context under
 *       {@link #BOOSTER};</li>
 *   <li>it then calls {@link Booster#boosting} on the labels and the current predictions;</li>
 *   <li>finally it increments {@code boostingObjs.numBoosting}.</li>
 * </ul>
 * The booster is a ranking booster when the configured loss type is a ranking loss, otherwise a
 * regular booster.
 */
public final class Boosting extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(Boosting.class);
    private static final long serialVersionUID = 9179338616517355846L;

    /** {@link ComContext} key under which the {@link Booster} created on superstep 1 is stored. */
    public static final String BOOSTER = "booster";

    /**
     * Performs one boosting round for this task.
     *
     * @param comContext iteration context; must already hold the {@link BoostingObjs} under
     *                   {@link InitBoostingObjs#BOOSTING_OBJS}
     */
    @Override
    public void calc(ComContext comContext) {
        BoostingObjs boostingObjs = (BoostingObjs) comContext.getObj(InitBoostingObjs.BOOSTING_OBJS);
        // While a weak learner (tree) is being built, the boosting step must not run.
        if (boostingObjs.inWeakLearner) {
            return;
        }

        LOG.info("taskId: {}, {} start", comContext.getTaskId(), Boosting.class.getSimpleName());

        // The booster is created once, on the first superstep, and reused from the context afterwards.
        if (comContext.getStepNo() == 1) {
            comContext.putObj(BOOSTER, createBooster(boostingObjs));
        }

        Booster booster = (Booster) comContext.getObj(BOOSTER);
        booster.boosting(boostingObjs, boostingObjs.data.getLabels(), boostingObjs.pred);
        boostingObjs.numBoosting++;

        LOG.info("taskId: {}, {} end", comContext.getTaskId(), Boosting.class.getSimpleName());
    }

    /**
     * Builds the booster for this task's data.
     *
     * <p>The booster type comes from {@code boostingObjs.params}. A ranking booster (using
     * {@code boostingObjs.rankingLoss}) is created when the configured loss type is a ranking loss;
     * otherwise a regular booster (using {@code boostingObjs.loss}) is created. Evaluation order
     * matches the original inline expression.
     *
     * @param boostingObjs shared boosting state holding params, loss functions and data
     * @return the newly created booster
     */
    private static Booster createBooster(BoostingObjs boostingObjs) {
        boolean ranking = LossUtils.isRanking((LossType) boostingObjs.params.get(LossUtils.LOSS_TYPE));
        BoosterType boosterType = (BoosterType) boostingObjs.params.get(BoosterType.BOOSTER_TYPE);

        if (ranking) {
            // Query slice is [0, queryIdOffset.length - 1). Presumably queryIdOffset stores query
            // boundaries (one more entry than the number of queries) -- TODO confirm.
            return BoosterFactory.createRankingBooster(
                boosterType,
                boostingObjs.rankingLoss,
                boostingObjs.data.getQueryIdOffset(),
                boostingObjs.data.getWeights(),
                new Slice(0, boostingObjs.data.getQueryIdOffset().length - 1),
                new Slice(0, boostingObjs.data.getM()));
        }

        // Slice [0, M) -- presumably all M rows held by this task; verify against the data class.
        return BoosterFactory.createBooster(
            boosterType,
            boostingObjs.loss,
            boostingObjs.data.getWeights(),
            new Slice(0, boostingObjs.data.getM()));
    }
}
