package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType;
import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeObj;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Row-sampling ("bagging") step executed by the parallel CART boosting loop.
 *
 * <p>Behaviour of {@link #calc(ComContext)}:
 * <ul>
 *   <li>If {@code boostingObjs.inWeakLearner} is set, it does nothing.</li>
 *   <li>If the configured loss is a ranking loss ({@link LossUtils#isRanking}), whole queries are
 *       sampled. The query ids are permuted with {@link TreeObj#shuffle}. Then the row ids of the
 *       first {@code queryBaggingCnt} queries, taken from {@code data.getQueryIdOffset()}, are
 *       written contiguously to the front of {@code boostingObjs.indices}.
 *       {@code boostingObjs.numBaggingInstances} is set to the number of rows written.</li>
 *   <li>Otherwise {@code boostingObjs.indices} is permuted in place with {@link TreeObj#shuffle}.
 *       {@code numBaggingInstances} is not touched here. NOTE(review): presumably it is set
 *       elsewhere (e.g. during initialization); confirm.</li>
 * </ul>
 */
public final class Bagging extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(Bagging.class);
    private static final long serialVersionUID = -4306518497657978646L;

    /**
     * Permutation of query ids {@code 0..numQueries-1}. It is allocated on step 1 of the ranking
     * branch and re-shuffled on every ranking call.
     */
    private int[] queryBaggingIndices;

    /**
     * Number of queries kept per round: {@code ceil(numQueries * SUBSAMPLING_RATIO)}, capped at
     * {@code numQueries}. It is computed on step 1 of the ranking branch.
     */
    private int queryBaggingCnt;

    /**
     * Fixed pool of preallocated {@code int[]} buffers, handed out in order by {@link #get()}
     * until {@link #reset()} rewinds the cursor.
     *
     * <p>There is no explicit capacity check. Calling {@link #get()} more than {@code poolSize}
     * times without a reset throws {@link ArrayIndexOutOfBoundsException}.
     */
    public static final class BaggingFeaturePool {
        private final int[][] pool;
        private int cursor;

        /**
         * @param poolSize    number of buffers in the pool
         * @param arrayLength length of every buffer
         */
        public BaggingFeaturePool(int poolSize, int arrayLength) {
            this.pool = new int[poolSize][arrayLength];
        }

        /** Returns the next unused buffer and advances the cursor. */
        public int[] get() {
            return pool[cursor++];
        }

        /** Rewinds the cursor so buffers are handed out again from the first one. */
        public void reset() {
            cursor = 0;
        }
    }

    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        final BoostingObjs boostingObjs =
            (BoostingObjs) comContext.getObj(InitBoostingObjs.BOOSTING_OBJS);

        // Sampling only happens between weak learners, never while one is being trained.
        if (boostingObjs.inWeakLearner) {
            return;
        }

        LOG.info("taskId: {}, {} start", comContext.getTaskId(), Bagging.class.getSimpleName());

        final LossType lossType = (LossType) boostingObjs.params.get(LossUtils.LOSS_TYPE);
        if (LossUtils.isRanking(lossType)) {
            bagQueries(boostingObjs, comContext);
        } else {
            TreeObj.shuffle(boostingObjs.indices, boostingObjs.instanceRandomizer);
        }

        LOG.info("taskId: {}, {} end", comContext.getTaskId(), Bagging.class.getSimpleName());
    }

    /**
     * Ranking-loss sampling. Keeps whole queries so that no query is split across the sample.
     *
     * <p>NOTE(review): the query permutation and {@code queryBaggingCnt} are only set up when
     * {@code getStepNo() == 1}. If step 1 ever returned early (e.g. {@code inWeakLearner} set),
     * a later step would shuffle a null array. Confirm that step 1 always reaches this method.
     */
    private void bagQueries(BoostingObjs boostingObjs, ComContext comContext) {
        final int[] queryIdOffset = boostingObjs.data.getQueryIdOffset();
        // The offsets array has one more entry than there are queries: query q spans
        // rows [queryIdOffset[q], queryIdOffset[q + 1]).
        final int numQueries = queryIdOffset.length - 1;

        if (comContext.getStepNo() == 1) {
            queryBaggingIndices = new int[numQueries];
            Arrays.setAll(queryBaggingIndices, queryId -> queryId);

            final double subsamplingRatio =
                (Double) boostingObjs.params.get(GbdtTrainParams.SUBSAMPLING_RATIO);
            queryBaggingCnt = (int) Math.min(numQueries, Math.ceil(numQueries * subsamplingRatio));
        }

        TreeObj.shuffle(queryBaggingIndices, boostingObjs.instanceRandomizer);

        // Append the rows of each selected query, in permuted query order.
        int written = 0;
        for (int k = 0; k < queryBaggingCnt; k++) {
            final int queryId = queryBaggingIndices[k];
            final int firstRow = queryIdOffset[queryId];
            final int endRow = queryIdOffset[queryId + 1];
            for (int row = firstRow; row < endRow; row++) {
                boostingObjs.indices[written++] = row;
            }
        }
        boostingObjs.numBaggingInstances = written;
    }

    /**
     * Draws a feature subset for one tree.
     *
     * <p>Steps:
     * <ol>
     *   <li>Takes the next buffer from {@code baggingFeaturePool}.</li>
     *   <li>Permutes {@code boostingObjs.featureIndices} in place with {@link TreeObj#shuffle}.</li>
     *   <li>Copies the first {@code numBaggingFeatures} entries into the buffer.</li>
     *   <li>Sorts the entire buffer ascending.</li>
     * </ol>
     *
     * <p>NOTE(review): this assumes each pool buffer has length {@code numBaggingFeatures}. A
     * longer buffer would have stale trailing entries sorted in with the new ones; confirm at the
     * pool's construction site.
     *
     * @return the pooled buffer holding the sorted sampled feature ids
     */
    public static int[] sampleFeatures(BoostingObjs boostingObjs, BaggingFeaturePool baggingFeaturePool) {
        final int[] sampled = baggingFeaturePool.get();
        final int[] allFeatures = boostingObjs.featureIndices;

        TreeObj.shuffle(allFeatures, boostingObjs.featureRandomizer);
        System.arraycopy(allFeatures, 0, sampled, 0, boostingObjs.numBaggingFeatures);
        Arrays.sort(sampled);

        return sampled;
    }
}
