package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile;
import com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster;
import com.alibaba.alink.operator.common.tree.parallelcart.communication.AllReduceT;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/BuildLocalSketch.class */
/**
 * Builds local weighted-quantile summaries for every continuous feature of the training data.
 *
 * <p>Each call fills one {@link EpsilonApproQuantile.WQSummary} per continuous feature and stores the
 * array in the context under {@link #SKETCH}. The summaries are built from the local data, the
 * booster's hessians and the bagging flags of the currently sampled instances. The number of sketched
 * features is published under {@link #FEATURE_SKETCH_LENGTH}. The summaries are presumably merged
 * across workers with {@link SketchReducer} (an {@code AllReduceT} consumer) — confirm against the
 * queue definition.
 */
public final class BuildLocalSketch extends ComputeFunction {
    public static final String SKETCH = "sketch";
    public static final String FEATURE_SKETCH_LENGTH = "featureSketchLength";
    private static final long serialVersionUID = 6357309989554983054L;

    /**
     * Per-feature sketch builders.
     *
     * <p>The array is created lazily on the first non-weak-learner call and then reused by later calls
     * on this instance.
     */
    EpsilonApproQuantile.SketchEntry[] entries;

    /**
     * All-reduce combiner for arrays of per-feature summaries.
     *
     * <p>For every index {@code k}, it combines {@code left[k]} with {@code right[k]} and then prunes the
     * combined summary (using the {@code maxSize} derived from the sketch ratio and eps) back into
     * {@code left[k]}.
     */
    public static final class SketchReducer
        implements AllReduceT.SerializableBiConsumer<EpsilonApproQuantile.WQSummary[], EpsilonApproQuantile.WQSummary[]> {
        private static final long serialVersionUID = -5949438933636979107L;

        /** Size limit passed to {@code setPrune}; equals {@code maxSize(sketchRatio, sketchEps)}. */
        private final int maxSize;

        /**
         * Creates a reducer whose prune limit is derived from the training parameters.
         *
         * @param params training parameters providing {@code SKETCH_RATIO} and {@code SKETCH_EPS}
         */
        public SketchReducer(Params params) {
            double sketchRatio = (Double) params.get(BaseGbdtTrainBatchOp.SKETCH_RATIO);
            double sketchEps = (Double) params.get(BaseGbdtTrainBatchOp.SKETCH_EPS);
            this.maxSize = BuildLocalSketch.maxSize(sketchRatio, sketchEps);
        }

        /**
         * Merges {@code right} into {@code left} element-wise, in place.
         *
         * @param left  accumulator summaries; each element is overwritten with the pruned merge
         * @param right summaries to merge in; not modified by this method
         * @throws RuntimeException (via {@code AkPreconditions.checkState}) if either array is null or
         *                          the lengths differ
         */
        @Override
        public void accept(EpsilonApproQuantile.WQSummary[] left, EpsilonApproQuantile.WQSummary[] right) {
            AkPreconditions.checkState(left != null && right != null && left.length == right.length);
            for (int k = 0; k < left.length; k++) {
                // Combine into a scratch summary first, then prune it back into the accumulator slot.
                EpsilonApproQuantile.WQSummary combined = new EpsilonApproQuantile.WQSummary();
                combined.setCombine(left[k], right[k]);
                left[k].setPrune(combined, this.maxSize);
            }
        }
    }

    /**
     * Builds the local summaries for the current superstep.
     *
     * <p>While {@code boostingObjs.inWeakLearner} is set, the method only publishes a feature-sketch
     * length of 0. Otherwise it:
     * <ol>
     *   <li>sorts the data and allocates empty summaries if absent (on step 1 only);</li>
     *   <li>marks the bagged instances;</li>
     *   <li>rebuilds every continuous feature's summary into the {@link #SKETCH} array;</li>
     *   <li>publishes the number of continuous features.</li>
     * </ol>
     *
     * @param comContext communication context holding the boosting objects, the booster and the sketches
     */
    @Override
    public void calc(ComContext comContext) {
        BoostingObjs boostingObjs = (BoostingObjs) comContext.getObj(InitBoostingObjs.BOOSTING_OBJS);
        if (boostingObjs.inWeakLearner) {
            comContext.putObj(FEATURE_SKETCH_LENGTH, Integer.valueOf(0));
            return;
        }

        // Only continuous features get a quantile sketch. Feature metadata and the feature count are
        // loop-invariant, so they are read once.
        FeatureMeta[] featureMetas = boostingObjs.data.getFeatureMetas();
        int numFeatures = boostingObjs.data.getN();
        int numContinuous = 0;
        for (int featureId = 0; featureId < numFeatures; featureId++) {
            if (featureMetas[featureId].getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
                numContinuous++;
            }
        }

        if (comContext.getStepNo() == 1) {
            boostingObjs.data.sort();
            // Allocate the shared summary array once; later steps refill it in place.
            if (comContext.getObj(SKETCH) == null) {
                EpsilonApproQuantile.WQSummary[] initialSummaries = new EpsilonApproQuantile.WQSummary[numContinuous];
                for (int k = 0; k < numContinuous; k++) {
                    initialSummaries[k] = new EpsilonApproQuantile.WQSummary();
                }
                comContext.putObj(SKETCH, initialSummaries);
            }
        }

        EpsilonApproQuantile.WQSummary[] summaries = (EpsilonApproQuantile.WQSummary[]) comContext.getObj(SKETCH);
        double sketchEps = (Double) boostingObjs.params.get(BaseGbdtTrainBatchOp.SKETCH_EPS);
        double sketchRatio = (Double) boostingObjs.params.get(BaseGbdtTrainBatchOp.SKETCH_RATIO);
        int maxSize = maxSize(sketchRatio, sketchEps);
        AkPreconditions.checkState(maxSize > 0);

        // Lazily create the per-feature sketch builders, sized from the local row count and eps.
        if (this.entries == null) {
            this.entries = new EpsilonApproQuantile.SketchEntry[numContinuous];
            for (int k = 0; k < numContinuous; k++) {
                EpsilonApproQuantile.SketchEntry entry = new EpsilonApproQuantile.SketchEntry();
                entry.sketch.limitSizeLevel(boostingObjs.data.getM(), sketchEps);
                this.entries[k] = entry;
            }
        }

        // Mark exactly the first numBaggingInstances indices as the instances sampled for this round.
        boostingObjs.baggingFlags.clear();
        for (int k = 0; k < boostingObjs.numBaggingInstances; k++) {
            boostingObjs.baggingFlags.set(boostingObjs.indices[k]);
        }

        Booster booster = (Booster) comContext.getObj(Boosting.BOOSTER);
        boostingObjs.data.createWQSummary(
            maxSize, sketchEps, this.entries, booster.getHessions(), boostingObjs.baggingFlags);

        // Export each builder's state into the shared summary array.
        for (int k = 0; k < this.entries.length; k++) {
            this.entries[k].sketch.getSummary(summaries[k]);
        }
        comContext.putObj(FEATURE_SKETCH_LENGTH, Integer.valueOf(numContinuous));
    }

    /**
     * Computes the summary size limit from the sketch parameters.
     *
     * @param sketchRatio the {@code SKETCH_RATIO} training parameter
     * @param sketchEps   the {@code SKETCH_EPS} training parameter
     * @return {@code (int) (sketchRatio / sketchEps)}; the quotient is truncated toward zero and
     *         saturates at {@link Integer#MAX_VALUE} when it exceeds the int range (e.g. when
     *         {@code sketchEps} is 0)
     */
    public static int maxSize(double sketchRatio, double sketchEps) {
        return (int) (sketchRatio / sketchEps);
    }
}
