package com.alibaba.alink.operator.common.tree.parallelcart;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.operator.common.tree.Node;
import com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile;
import com.alibaba.alink.operator.common.tree.parallelcart.NodeInfoPair;
import com.alibaba.alink.operator.common.tree.parallelcart.data.Data;
import com.alibaba.alink.operator.common.tree.parallelcart.data.Slice;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/SplitInstances.class */
/**
 * Communication-queue step that applies the best splits found for every queued tree node.
 *
 * <p>For each node it either turns the node into a leaf or partitions its instances into two
 * children that are queued for the next pass.
 */
public final class SplitInstances extends ComputeFunction {
    private static final Logger LOG = LoggerFactory.getLogger(SplitInstances.class);
    private static final long serialVersionUID = -1127287176494832930L;

    /**
     * Processes every {@link NodeInfoPair} present in the tree queue at entry.
     *
     * <p>Each pair's {@code small} node, then its optional {@code big} node, consumes the next entry
     * of the {@code "best"} node array, in that order. The array is presumably produced by an
     * earlier step and ordered to match this traversal (TODO confirm).
     *
     * <p>A node becomes a leaf when its best split says so, or when splitting it could exceed
     * {@code MAX_LEAVES}. Otherwise its instances are split and the resulting child pair is
     * appended to the queue. When no pair remains queued after the pass,
     * {@code boostingObjs.inWeakLearner} is cleared.
     *
     * @param comContext context supplying the boosting objects, the tree objects and the
     *     {@code "best"} node array
     */
    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        LOG.info("taskId: {}, {} start", Integer.valueOf(comContext.getTaskId()), SplitInstances.class.getSimpleName());
        BoostingObjs boostingObjs = (BoostingObjs) comContext.getObj(InitBoostingObjs.BOOSTING_OBJS);
        HistogramBaseTreeObjs tree = (HistogramBaseTreeObjs) comContext.getObj(InitTreeObjs.TREE);
        Node[] bestNodes = (Node[]) comContext.getObj("best");
        int maxLeaves = ((Integer) boostingObjs.params.get(GbdtTrainParams.MAX_LEAVES)).intValue();

        // Count every node still awaiting a decision: one per pair, plus one for a present big node.
        int pending = 0;
        for (NodeInfoPair pair : tree.queue) {
            pending += pair.big == null ? 1 : 2;
        }

        // Snapshot the pair count: children appended during this pass are handled in the next pass.
        int pairCount = tree.queue.size();
        int bestIndex = 0;
        for (int p = 0; p < pairCount; p++) {
            NodeInfoPair pair = tree.queue.poll();
            assert pair != null;

            // The node being settled is no longer pending, hence "pending - 1".
            pending = settleNode(pair.small, bestNodes[bestIndex++], pending - 1, maxLeaves, tree, boostingObjs);
            if (pair.big != null) {
                pending = settleNode(pair.big, bestNodes[bestIndex++], pending - 1, maxLeaves, tree, boostingObjs);
            }
        }

        if (tree.queue.isEmpty()) {
            boostingObjs.inWeakLearner = false;
        }
        LOG.info("taskId: {}, {} end", Integer.valueOf(comContext.getTaskId()), SplitInstances.class.getSimpleName());
    }

    /**
     * Applies one best split to one queued node, making it either a leaf or a split node.
     *
     * <p>The leaf-budget check assumes every other pending node ends as at least one leaf. Splitting
     * this node would add two more, so the node is forced to be a leaf when
     * {@code leaves + otherPending + 2 > maxLeaves}.
     *
     * <p>Assumes {@link Node#copy} transfers the leaf status of {@code best}; the leaf test is made
     * on the copied node.
     *
     * @param nodeInfo node to settle
     * @param best best split computed for this node
     * @param otherPending pending nodes excluding this one
     * @param maxLeaves configured maximum number of leaves
     * @param tree tree state whose leaves / queue are updated
     * @param boostingObjs supplies the instance indices and data used for splitting
     * @return the new pending count: {@code otherPending} if the node became a leaf, or
     *     {@code otherPending + 2} if two children were queued
     */
    private static int settleNode(
        NodeInfoPair.NodeInfo nodeInfo,
        Node best,
        int otherPending,
        int maxLeaves,
        HistogramBaseTreeObjs tree,
        BoostingObjs boostingObjs) {
        nodeInfo.node.copy(best);
        if (nodeInfo.node.isLeaf() || tree.leaves.size() + otherPending + 2 > maxLeaves) {
            nodeInfo.node.makeLeaf();
            nodeInfo.shrinkageMemory();
            tree.leaves.add(nodeInfo);
            return otherPending;
        }
        tree.queue.add(split(nodeInfo, tree.getDynamicSummary(nodeInfo.node), boostingObjs.indices, boostingObjs.data));
        // Presumably swaps in the node's actual split representation after partitioning; TODO confirm.
        tree.replaceWithActual(nodeInfo.node);
        return otherPending + 2;
    }

    /**
     * Splits the instances of {@code nodeInfo} by its node's split and builds the child pair.
     *
     * <p>The in-bag slice and the out-of-bag slice are each partitioned with
     * {@code data.splitInstances}. The returned index is the boundary between the left child's
     * range {@code [start, boundary)} and the right child's range {@code [boundary, end)}. This
     * presumably reorders {@code iArr} in place (TODO confirm against {@code Data}).
     *
     * <p>Two fresh child nodes are attached to the parent. The left child inherits the parent's
     * {@code baggingFeatures}; the right child gets {@code null}.
     *
     * @param nodeInfo node being split, with its in-bag and out-of-bag slices
     * @param wQSummary dynamic summary passed through to {@code data.splitInstances}
     * @param iArr instance index array passed through to {@code data.splitInstances}
     * @param data training data that performs the partitioning
     * @return pair of child infos at {@code depth + 1}: left child first, right child second
     */
    public static NodeInfoPair split(NodeInfoPair.NodeInfo nodeInfo, EpsilonApproQuantile.WQSummary wQSummary, int[] iArr, Data data) {
        Node parent = nodeInfo.node;
        int sliceBoundary = data.splitInstances(parent, wQSummary, iArr, nodeInfo.slice);
        int oobBoundary = data.splitInstances(parent, wQSummary, iArr, nodeInfo.oob);

        parent.setNextNodes(new Node[]{new Node(), new Node()});
        Node[] children = parent.getNextNodes();
        int childDepth = nodeInfo.depth + 1;

        NodeInfoPair.NodeInfo left = new NodeInfoPair.NodeInfo(
            children[0],
            new Slice(nodeInfo.slice.start, sliceBoundary),
            new Slice(nodeInfo.oob.start, oobBoundary),
            childDepth,
            nodeInfo.baggingFeatures);
        NodeInfoPair.NodeInfo right = new NodeInfoPair.NodeInfo(
            children[1],
            new Slice(sliceBoundary, nodeInfo.slice.end),
            new Slice(oobBoundary, nodeInfo.oob.end),
            childDepth,
            null);
        return new NodeInfoPair(left, right);
    }
}
