package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.operator.common.tree.parallelcart.data.Data;
import com.alibaba.alink.operator.common.tree.parallelcart.data.DataUtil;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.GBRankLoss;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LambdaLoss;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LeastSquare;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LogLoss;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.RankingLossFunc;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.UnaryLossFuncWithPrior;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import com.alibaba.alink.params.shared.tree.HasSeed;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Executors;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/InitBoostingObjs.class */
/**
 * Initializer for the parallel CART boosting iteration.
 *
 * <p>On the first superstep only, it does the following:
 * <ul>
 *   <li>builds this task's {@link Data} from the training rows and feature metas held in the
 *       {@link ComContext};</li>
 *   <li>removes those raw inputs from the context;</li>
 *   <li>creates the configured unary or ranking loss;</li>
 *   <li>publishes the resulting {@link BoostingObjs} under {@link #BOOSTING_OBJS}.</li>
 * </ul>
 * Later supersteps are no-ops.
 */
public final class InitBoostingObjs extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(InitBoostingObjs.class);
    private static final long serialVersionUID = 3534441875143136651L;

    /** Context key under which the initialized {@link BoostingObjs} is published. */
    public static final String BOOSTING_OBJS = "boostingObjs";
    /** Context key of the feature metadata; consumed and removed by {@link #calc}. */
    public static final String FEATURE_METAS = "featureMetas";
    /** Context key of this task's training rows; consumed and removed by {@link #calc}. */
    public static final String TRAIN_DATA = "trainData";

    /**
     * Context key of the label statistics used by unary losses. The value is read as a list whose
     * first element is a {@code Tuple2<Double, Long>}. Presumably that pair is (sum of labels,
     * number of rows); confirm against the producer.
     */
    private static final String LABEL_STATS = "gbdt.y.sum";

    /** Size of the fixed thread pool stored in {@code BoostingObjs.executorService}. */
    private static final int NUM_THREADS = 8;

    private final Params params;

    /**
     * @param params training parameters. They are kept and also exposed to downstream steps through
     *     {@code BoostingObjs.params}.
     */
    public InitBoostingObjs(Params params) {
        this.params = params;
    }

    /**
     * Builds and publishes the {@link BoostingObjs} for this task on superstep 1; does nothing on
     * any later superstep.
     *
     * @param comContext iteration context; {@link #TRAIN_DATA} and {@link #FEATURE_METAS} are read
     *     and removed, and {@link #BOOSTING_OBJS} is written
     * @throws UnsupportedOperationException if the configured loss type is not supported
     */
    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        if (comContext.getStepNo() != 1) {
            return;
        }
        final int taskId = comContext.getTaskId();
        LOG.info("taskId: {}, {} start", taskId, InitBoostingObjs.class.getSimpleName());

        final Data data = loadTrainData(comContext);
        final int numInstances = data.getM();
        final int numFeatures = data.getN();
        LOG.info("taskId: {}, data shape, M: {}, N: {}", taskId, numInstances, numFeatures);

        final LossType lossType = (LossType) params.get(LossUtils.LOSS_TYPE);

        BoostingObjs boostingObjs = new BoostingObjs();
        boostingObjs.params = params;
        boostingObjs.data = data;
        if (LossUtils.isRanking(lossType)) {
            boostingObjs.rankingLoss = createRankingLoss(
                lossType, params, data.getQueryIdOffset(), data.getLabels(), data.getWeights());
            // NOTE(review): prior and numBaggingInstances are left at their defaults for ranking
            // losses; confirm that is intended.
        } else {
            boostingObjs.loss = createUnaryLoss(lossType, readLabelStats(comContext));
            boostingObjs.prior = boostingObjs.loss.prior();
            final double subsamplingRatio = (Double) params.get(GbdtTrainParams.SUBSAMPLING_RATIO);
            // ceil(M * ratio), capped at M.
            boostingObjs.numBaggingInstances =
                (int) Math.min(numInstances, Math.ceil(numInstances * subsamplingRatio));
        }

        final long seed = (Long) params.get(HasSeed.SEED);
        // The instance randomizer is offset by taskId, so each task draws a different stream.
        boostingObjs.instanceRandomizer = new Random(taskId + seed);
        // The feature randomizer uses the same seed on every task. Presumably this keeps feature
        // sampling identical across tasks; confirm.
        boostingObjs.featureRandomizer = new Random(seed);
        boostingObjs.numBaggingFeatures = baggingFeatureCount(numFeatures);
        boostingObjs.featureIndices = identityIndices(numFeatures);

        boostingObjs.indices = identityIndices(numInstances);
        boostingObjs.baggingFlags = new BitSet(numInstances);
        boostingObjs.pred = new double[numInstances];
        // Every instance's prediction starts at the loss prior.
        Arrays.fill(boostingObjs.pred, boostingObjs.prior);

        boostingObjs.inWeakLearner = false;
        boostingObjs.numBoosting = 0;
        // NOTE(review): this pool is not shut down here. Executors' default thread factory creates
        // non-daemon threads, so confirm that whoever owns BoostingObjs shuts it down.
        boostingObjs.executorService = Executors.newFixedThreadPool(NUM_THREADS);
        comContext.putObj(BOOSTING_OBJS, boostingObjs);
        LOG.info("taskId: {}, {} end", taskId, InitBoostingObjs.class.getSimpleName());
    }

    /**
     * Loads this task's training rows into a {@link Data} instance, then removes the raw rows and
     * feature metas from the context.
     */
    @SuppressWarnings("unchecked")
    private Data loadTrainData(ComContext comContext) {
        List<Row> rows = (List<Row>) comContext.getObj(TRAIN_DATA);
        if (rows == null) {
            // The rows may be absent for this task. Normalize to an empty list so the loaders
            // below never receive null.
            rows = Collections.emptyList();
        }
        final boolean useEpsilonApproQuantile =
            (Boolean) params.get(BaseGbdtTrainBatchOp.USE_EPSILON_APPRO_QUANTILE);
        final Data data = DataUtil.createData(
            params, (List) comContext.getObj(FEATURE_METAS), rows.size(), useEpsilonApproQuantile);
        if (useEpsilonApproQuantile) {
            data.loadFromRowWithContinues(rows);
        } else {
            data.loadFromRow(rows);
        }
        // Drop the raw inputs from the context now that they have been loaded into Data.
        comContext.removeObj(FEATURE_METAS);
        comContext.removeObj(TRAIN_DATA);
        return data;
    }

    /** Returns the first element of the label-statistics list stored under {@link #LABEL_STATS}. */
    @SuppressWarnings("unchecked")
    private static Tuple2<Double, Long> readLabelStats(ComContext comContext) {
        List<Tuple2<Double, Long>> stats = (List<Tuple2<Double, Long>>) comContext.getObj(LABEL_STATS);
        return stats.get(0);
    }

    /** Returns {@code {0, 1, ..., n - 1}}. */
    private static int[] identityIndices(int n) {
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = i;
        }
        return indices;
    }

    /**
     * Computes how many features to sample per tree.
     *
     * <p>The value is {@code (int) (featureSubsamplingRatio * numFeatures)} clamped to
     * {@code [1, numFeatures]}. When {@code numFeatures == 0} the result is still 1.
     */
    private int baggingFeatureCount(int numFeatures) {
        final double ratio = (Double) params.get(GbdtTrainParams.FEATURE_SUBSAMPLING_RATIO);
        return Math.max(1, Math.min((int) (ratio * numFeatures), numFeatures));
    }

    /**
     * Creates a non-ranking loss from the label statistics.
     *
     * <p>For {@code LOG_LOSS} the two arguments are {@code f0} and {@code f1 - f0}. For
     * {@code LEASE_SQUARE} they are {@code f0} and {@code f1}.
     *
     * @throws UnsupportedOperationException for any other loss type
     */
    private UnaryLossFuncWithPrior createUnaryLoss(LossType lossType, Tuple2<Double, Long> labelStats) {
        switch (lossType) {
            case LOG_LOSS:
                return new LogLoss(labelStats.f0, labelStats.f1.doubleValue() - labelStats.f0);
            case LEASE_SQUARE:
                return new LeastSquare(labelStats.f0, labelStats.f1.doubleValue());
            default:
                throw new UnsupportedOperationException("Unsupported loss.");
        }
    }

    /**
     * Creates a ranking loss over this task's query groups.
     *
     * @throws UnsupportedOperationException if {@code lossType} is not a supported ranking loss
     */
    private RankingLossFunc createRankingLoss(
        LossType lossType, Params params, int[] queryIdOffset, double[] labels, double[] weights) {
        switch (lossType) {
            case GBRANK:
                return new GBRankLoss(params, queryIdOffset, labels, weights);
            case LAMBDA_NDCG:
                return new LambdaLoss(params, LambdaLoss.LambdaType.NDCG, queryIdOffset, labels, weights);
            case LAMBDA_DCG:
                return new LambdaLoss(params, LambdaLoss.LambdaType.DCG, queryIdOffset, labels, weights);
            default:
                throw new UnsupportedOperationException("Unsupported loss.");
        }
    }
}
