package com.alibaba.alink.operator.common.tree.paralleltree;

import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.Node;
import com.alibaba.alink.operator.common.tree.paralleltree.NodeInfoPair;
import com.alibaba.alink.params.classification.GbdtTrainParams;
import com.alibaba.alink.params.classification.RandomForestTrainParams;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Per-subtask working state of the parallel, histogram-based tree trainer (parameters are read
 * from {@link Params} using the random-forest / GBDT parameter definitions).
 *
 * <p>An instance owns the local binned feature values ({@link #features}), the bagged row
 * partitions of every tree ({@link #partitions}), a FIFO queue of nodes still to be split, and
 * the buffers used to compute per-node statistics. Callers (not visible here — presumably the
 * distributed training loop; confirm against the caller) repeatedly pick a batch of queued nodes
 * with {@link #determineLoopNode()}, draw feature bags with {@link #initialLoop()}, collect
 * statistics with {@link #stat()}, and split with {@link #bestSplit()} until
 * {@link #terminationCriterion()} returns {@code true}.
 *
 * <p>Two statistics layouts exist, selected by {@code useStatPair()}:
 * <ul>
 *   <li><b>stat-pair mode</b> (fewer bagged features than feature columns and
 *   {@link #rootBagging} is {@code false}): both partitions of a queued node are counted
 *   directly, so one statistics unit is {@code 2 * lenPerStat()} doubles;</li>
 *   <li><b>subtraction mode</b> (otherwise): only the {@code small} partition is counted and
 *   the {@code big} sibling's statistics are derived as parent minus small, using the parent
 *   copy kept in {@link #parentHistPool}.</li>
 * </ul>
 *
 * @param <LABELARRAY> type of the label container handed to {@code setLabels}
 */
public abstract class TreeObj<LABELARRAY> implements Serializable {
    private static final Logger LOG = LoggerFactory.getLogger(TreeObj.class);
    /** Number of rows held locally by this subtask. */
    public static final ParamInfo<Integer> N_LOCAL_ROW = ParamInfoFactory.createParamInfo("nLocalRow", Integer.class).setDescription("n local row").setRequired().build();
    /** Index of this subtask. */
    public static final ParamInfo<Integer> TASK_ID = ParamInfoFactory.createParamInfo("taskId", Integer.class).setDescription("task id").setRequired().build();
    /** Total number of subtasks. */
    public static final ParamInfo<Integer> NUM_OF_SUBTASKS = ParamInfoFactory.createParamInfo("numOfSubTasks", Integer.class).setDescription("numOfSubTasks").setRequired().build();
    private static final long serialVersionUID = 790791300919497663L;

    /** Size of a {@code double} in bytes; converts the memory budget into an element count. */
    private static final long BYTES_PER_DOUBLE = 8L;
    /**
     * Upper bound for {@link #maxHistBufferSize}. {@link #minusHist} is allocated with twice
     * that many elements, so the bound keeps {@code maxHistBufferSize * 2} inside {@code int}.
     */
    private static final int MAX_HIST_BUFFER_SIZE = Integer.MAX_VALUE / 2;

    /** Maximum number of doubles the histogram buffer may use per round (from MAX_MEMORY_IN_MB). */
    protected int maxHistBufferSize;
    /** Maximum number of queued nodes processed in one round ({@code maxHistBufferSize / lenStatUnit()}). */
    protected int maxLoopBufferSize;
    protected int taskId;
    protected int numOfSubTasks;
    /**
     * Integer (binned) feature values laid out feature-major: the value of feature {@code f}
     * for local row {@code r} is at {@code f * numLocalRow + r}.
     */
    protected int[] features;
    /**
     * Local row ids of every tree's bag; tree {@code t} owns the slice
     * {@code [t * numLocalBaggingRow, (t + 1) * numLocalBaggingRow)}. Reordered in place when
     * nodes are split so that each node's rows are contiguous.
     */
    protected int[] partitions;
    /** Statistics buffer supplied via {@link #setHist(double[])}; filled by stat, read by bestSplit. */
    protected double[] hist;
    /** Per-child statistics handed to the abstract {@code bestSplit}, one unit per child slot. */
    protected double[] minusHist;
    /** Permutation buffer of feature indices used to draw per-node feature bags. */
    protected int[] randomShuffleBuf;
    /** Parent statistics kept for sibling subtraction (only allocated in subtraction mode). */
    protected BufferPool parentHistPool;
    /** Nodes waiting to be split, processed FIFO. */
    protected Deque<NodeInfoPair> queue;
    /** Nodes taken from {@link #queue} for the current round; first {@link #loopBufferSize} are valid. */
    protected NodeInfoPair[] loopBuffer;
    protected int loopBufferSize;
    protected Params params;
    protected int nFeatureCol;
    protected int numLocalRow;
    /** Rows per tree after sub-sampling: {@code min(numLocalRow, ceil(numLocalRow * subSamplingRatio))}. */
    protected int numLocalBaggingRow;
    protected int nBin;
    protected double subSamplingRatio;
    protected int numTrees;
    protected FeatureMeta[] featureMetas;
    protected FeatureMeta labelMeta;
    /** Maps a bin index of a continuous feature back to an actual split value. */
    protected QuantileDiscretizerModelDataConverter quantileDiscretizerModel;
    /** Random source for row sampling (tree bags). */
    protected Random randomSample;
    /** Random source for feature bagging. */
    protected Random randomFeature;
    protected final int maxDepth;
    protected final int minSamplesPerLeaf;
    /** Seed for both random sources; defaults to 0. Note: an instance (not static) field. */
    public final ParamInfo<Long> SEED = ParamInfoFactory.createParamInfo("seed", Long.class).setDescription("seed").setHasDefaultValue(0L).build();
    /** Root node of every tree, in tree order. */
    protected List<Node> roots = new LinkedList<>();
    /** When {@code true}, feature bags drawn for a root are inherited by its descendants. */
    protected boolean rootBagging = false;

    /**
     * Reads the training configuration from {@code params}.
     *
     * @param params training parameters; must contain NUM_TREES, FEATURE_COLS, N_LOCAL_ROW,
     *     MAX_BINS, TASK_ID, SUBSAMPLING_RATIO, NUM_OF_SUBTASKS, MAX_MEMORY_IN_MB, MAX_DEPTH and
     *     MIN_SAMPLES_PER_LEAF
     * @param quantileDiscretizerModelDataConverter model used to map bin indices to split values
     * @param featureMetaArr metadata of every feature column
     * @param featureMeta metadata of the label column
     */
    public TreeObj(Params params, QuantileDiscretizerModelDataConverter quantileDiscretizerModelDataConverter, FeatureMeta[] featureMetaArr, FeatureMeta featureMeta) {
        this.params = params;
        this.featureMetas = featureMetaArr;
        this.labelMeta = featureMeta;
        this.quantileDiscretizerModel = quantileDiscretizerModelDataConverter;
        this.numTrees = ((Integer) params.get(RandomForestTrainParams.NUM_TREES)).intValue();
        this.nFeatureCol = ((String[]) params.get(RandomForestTrainParams.FEATURE_COLS)).length;
        this.numLocalRow = ((Integer) params.get(N_LOCAL_ROW)).intValue();
        this.nBin = ((Integer) params.get(GbdtTrainParams.MAX_BINS)).intValue();
        this.taskId = ((Integer) params.get(TASK_ID)).intValue();
        this.subSamplingRatio = ((Double) params.get(RandomForestTrainParams.SUBSAMPLING_RATIO)).doubleValue();
        this.numOfSubTasks = ((Integer) params.get(NUM_OF_SUBTASKS)).intValue();
        this.maxHistBufferSize = computeMaxHistBufferSize(((Integer) params.get(RandomForestTrainParams.MAX_MEMORY_IN_MB)).intValue());
        this.numLocalBaggingRow = (int) Math.min(this.numLocalRow, Math.ceil(this.numLocalRow * this.subSamplingRatio));
        this.randomSample = new Random(((Long) params.get(this.SEED)).longValue());
        this.randomFeature = new Random(((Long) params.get(this.SEED)).longValue());
        this.maxDepth = ((Integer) params.get(RandomForestTrainParams.MAX_DEPTH)).intValue();
        this.minSamplesPerLeaf = ((Integer) params.get(RandomForestTrainParams.MIN_SAMPLES_PER_LEAF)).intValue();
    }

    /**
     * Converts a memory budget in MB into a number of {@code double} elements.
     *
     * <p>The product is computed in {@code long}: in {@code int} arithmetic any budget of
     * 2048 MB or more overflows to a negative or zero size. The result is clamped so that it
     * (and twice it) fits in an {@code int}.
     *
     * @param maxMemoryInMb memory budget in megabytes
     * @return number of doubles fitting in the budget, clamped to {@link #MAX_HIST_BUFFER_SIZE}
     */
    private static int computeMaxHistBufferSize(int maxMemoryInMb) {
        long elements = (long) maxMemoryInMb * 1024L * 1024L / BYTES_PER_DOUBLE;
        // Lower clamp only guards the cast; a negative budget still fails downstream as before.
        return (int) Math.max(Integer.MIN_VALUE, Math.min(elements, MAX_HIST_BUFFER_SIZE));
    }

    /** Creates a fresh, unattached tree node. */
    public static final Node ofNode() {
        return new Node();
    }

    /**
     * Returns the node that a queued entry should be split into.
     *
     * <p>For a root entry the entry's own node is returned. Otherwise a new child is created and
     * stored in the parent's child array: at index 0 when {@code z} is {@code true}, at index 1
     * otherwise. The two-element array is created on first use.
     *
     * @param nodeInfoPair queued entry whose {@code parentNode} receives the child
     * @param z {@code true} for the child built from the {@code small} partition
     * @return the root node itself or the newly attached child
     */
    public static final Node ofNode(NodeInfoPair nodeInfoPair, boolean z) {
        if (nodeInfoPair.root()) {
            return nodeInfoPair.parentNode;
        }
        Node parent = nodeInfoPair.parentNode;
        Node child = ofNode();
        if (parent.getNextNodes() == null) {
            parent.setNextNodes(new Node[2]);
        }
        parent.getNextNodes()[z ? 0 : 1] = child;
        return child;
    }

    /** Sets the feature-major binned feature array (see {@link #features}). */
    public void setFeatures(int[] iArr) {
        this.features = iArr;
    }

    /**
     * Allocates the per-round buffers and the bagged row partitions.
     *
     * <p>Must be called after {@link #rootBagging} and the statistics layout are final, since
     * {@code lenStatUnit()} and the pool allocation depend on them.
     */
    public void initBuffer() {
        this.maxLoopBufferSize = this.maxHistBufferSize / lenStatUnit();
        this.queue = new ArrayDeque<>(this.maxLoopBufferSize * 2);
        this.loopBuffer = new NodeInfoPair[this.maxLoopBufferSize];
        this.loopBufferSize = 0;
        // Twice the histogram budget: a node's statistics produce up to two child units.
        this.minusHist = new double[this.maxHistBufferSize * 2];
        if (!useStatPair()) {
            this.parentHistPool = new BufferPool(lenPerStat());
        }
        initPartitions();
        this.randomShuffleBuf = new int[this.nFeatureCol];
        for (int col = 0; col < this.nFeatureCol; col++) {
            this.randomShuffleBuf[col] = col;
        }
    }

    /** Sets the statistics buffer used by {@link #stat()} and {@link #bestSplit()}. */
    public void setHist(double[] dArr) {
        this.hist = dArr;
    }

    /** Returns the histogram budget in number of doubles. */
    public int getMaxHistBufferSize() {
        return this.maxHistBufferSize;
    }

    /**
     * Draws each tree's row bag: shuffles all local row ids with {@link #randomSample} and keeps
     * the first {@link #numLocalBaggingRow} of every shuffle.
     */
    protected void initPartitions() {
        int[] rowIds = new int[this.numLocalRow];
        for (int row = 0; row < this.numLocalRow; row++) {
            rowIds[row] = row;
        }
        this.partitions = new int[this.numLocalBaggingRow * this.numTrees];
        for (int tree = 0; tree < this.numTrees; tree++) {
            shuffle(rowIds, this.randomSample);
            System.arraycopy(rowIds, 0, this.partitions, tree * this.numLocalBaggingRow, this.numLocalBaggingRow);
        }
    }

    /**
     * Creates the queue entry for a node that is about to be split further.
     *
     * <p>A negative {@code i} is stored directly as {@code parentQueueId} (used for roots). In
     * subtraction mode a non-negative {@code i} copies the node's statistics unit from
     * {@link #minusHist} into a pooled buffer, whose id becomes {@code parentQueueId}, so the
     * sibling can later be derived by subtraction.
     *
     * @param node the node the entry refers to
     * @param i slot of the node's statistics in {@link #minusHist}, or negative for none
     * @param nodeInfoPair the parent entry, or {@code null} for a root
     * @return the new entry (depth and inherited feature bag filled in)
     */
    public final NodeInfoPair ofNodeInfoPair(Node node, int i, NodeInfoPair nodeInfoPair) {
        NodeInfoPair entry = new NodeInfoPair();
        if (i < 0) {
            entry.parentQueueId = i;
        } else if (!useStatPair()) {
            int pooledId = this.parentHistPool.nextValidId();
            System.arraycopy(this.minusHist, i * lenStatUnit(), this.parentHistPool.get(pooledId), 0, lenStatUnit());
            entry.parentQueueId = pooledId;
        }
        if (nodeInfoPair == null) {
            entry.depth = 1;
        } else {
            entry.depth = nodeInfoPair.depth + 1;
            entry.baggingFeatures = nodeInfoPair.baggingFeatures;
        }
        entry.parentNode = node;
        return entry;
    }

    /** Number of doubles of statistics collected for one partition (defined by the subclass). */
    public abstract int lenPerStat();

    /** Number of doubles reserved per queued node: two partitions in stat-pair mode, else one. */
    public int lenStatUnit() {
        return useStatPair() ? lenPerStat() * 2 : lenPerStat();
    }

    /** Returns the statistics buffer. */
    public final double[] hist() {
        return this.hist;
    }

    /** Number of leading doubles of {@link #hist} used by the current round. */
    public final int histLen() {
        return this.loopBufferSize * lenStatUnit();
    }

    /** Supplies the local labels to the subclass. */
    public abstract void setLabels(LABELARRAY labelarray);

    /**
     * Collects statistics for one partition. Called by {@link #stat()} with
     * ({@code loop index, partition start, partition end, hist start, hist end}), where the
     * partition bounds index {@link #partitions} and the hist bounds index {@link #hist}.
     */
    public abstract void stat(int i, int i2, int i3, int i4, int i5);

    /** Returns the queue of nodes still to be split. */
    public final Deque<NodeInfoPair> getQueue() {
        return this.queue;
    }

    /** Creates one root per tree, queues it over that tree's row slice, and records it in {@link #roots}. */
    public final void initialRoot() {
        for (int tree = 0; tree < this.numTrees; tree++) {
            Node root = ofNode();
            addNodeInfoPair(ofNodeInfoPair(root, -1, null).initialRoot(tree * this.numLocalBaggingRow, (tree + 1) * this.numLocalBaggingRow));
            this.roots.add(root);
        }
    }

    /** Appends an entry to the end of the pending queue. */
    public final void addNodeInfoPair(NodeInfoPair nodeInfoPair) {
        this.queue.addLast(nodeInfoPair);
    }

    /**
     * Moves entries from the head of the queue into {@link #loopBuffer} until either the loop
     * buffer or the histogram budget would be exceeded, or the queue is empty.
     */
    public final void determineLoopNode() {
        int usedHistLen = 0;
        this.loopBufferSize = 0;
        while (this.loopBufferSize < this.maxLoopBufferSize && this.queue.peekFirst() != null) {
            usedHistLen += lenStatUnit();
            if (usedHistLen > this.maxHistBufferSize) {
                return;
            }
            this.loopBuffer[this.loopBufferSize] = this.queue.pollFirst();
            this.loopBufferSize++;
        }
    }

    /**
     * Draws a random feature bag for each node of the current round when feature bagging is
     * enabled (fewer bagged features than columns). With {@link #rootBagging} set, nodes that
     * already inherited a bag keep it.
     */
    public final void initialLoop() {
        int bagSize = baggingFeatureCount();
        if (bagSize == this.nFeatureCol) {
            return;
        }
        for (int n = 0; n < this.loopBufferSize; n++) {
            NodeInfoPair entry = this.loopBuffer[n];
            if (entry.baggingFeatures == null || !this.rootBagging) {
                entry.baggingFeatures = new int[bagSize];
                shuffle(this.randomShuffleBuf, this.randomFeature);
                System.arraycopy(this.randomShuffleBuf, 0, entry.baggingFeatures, 0, bagSize);
                LOG.info("taskId: {}, loopBuffer: {}, randomFeatureEnd", this.taskId, JsonConverter.gson.toJson(entry));
            }
        }
    }

    /**
     * Shuffles {@code iArr} in place (Fisher–Yates, from the last element downwards).
     *
     * @param iArr array to permute
     * @param random source of randomness
     */
    public static final void shuffle(int[] iArr, Random random) {
        for (int last = iArr.length - 1; last > 0; last--) {
            int pick = random.nextInt(last + 1);
            if (pick != last) {
                int tmp = iArr[pick];
                iArr[pick] = iArr[last];
                iArr[last] = tmp;
            }
        }
    }

    /**
     * Collects statistics for every node of the current round into {@link #hist}. Node {@code n}
     * uses the unit starting at {@code n * lenStatUnit()}; in stat-pair mode the {@code big}
     * partition is written right after the {@code small} one.
     *
     * @return number of nodes processed in this round
     */
    public final int stat() {
        for (int n = 0; n < this.loopBufferSize; n++) {
            NodeInfoPair entry = this.loopBuffer[n];
            int smallFrom = n * lenStatUnit();
            int smallTo = smallFrom + lenPerStat();
            stat(n, entry.small.start, entry.small.end, smallFrom, smallTo);
            if (useStatPair() && entry.big != null) {
                stat(n, entry.big.start, entry.big.end, smallFrom + lenPerStat(), smallTo + lenPerStat());
            }
        }
        return this.loopBufferSize;
    }

    /**
     * Chooses the split of {@code node} from the statistics in slot {@code i} of
     * {@link #minusHist} (implemented by the subclass).
     */
    public abstract void bestSplit(Node node, int i, NodeInfoPair nodeInfoPair) throws Exception;

    /**
     * Turns the current round's statistics into splits.
     *
     * <p>For every node in the round, the {@code small}-partition child is split from its slot in
     * {@link #minusHist}. If a {@code big} partition exists, its child uses the next slot; in
     * subtraction mode that slot is computed as pooled parent statistics minus the small
     * partition's statistics, and the pooled buffer is released. Non-leaf children are queued.
     *
     * @throws Exception propagated from the subclass {@code bestSplit}
     */
    public final void bestSplit() throws Exception {
        double[] localHist = this.hist;
        if (useStatPair()) {
            // Stat-pair units already contain both partitions; copy them over verbatim.
            System.arraycopy(localHist, 0, this.minusHist, 0, histLen());
        }
        int unitLen = lenStatUnit();
        int slot = 0;
        for (int n = 0; n < this.loopBufferSize; n++) {
            NodeInfoPair entry = this.loopBuffer[n];
            int histOffset = n * unitLen;
            if (!useStatPair()) {
                System.arraycopy(localHist, histOffset, this.minusHist, slot * unitLen, unitLen);
            }
            Node smallChild = ofNode(entry, true);
            bestSplit(smallChild, slot, entry);
            split(smallChild, entry, true, slot);
            replaceWithActual(smallChild);
            if (entry.big != null) {
                if (!useStatPair()) {
                    double[] parentHist = this.parentHistPool.get(entry.parentQueueId);
                    int bigOffset = (slot + 1) * unitLen;
                    for (int k = 0; k < unitLen; k++) {
                        this.minusHist[bigOffset + k] = parentHist[k] - localHist[histOffset + k];
                    }
                    this.parentHistPool.release(entry.parentQueueId);
                }
                slot++;
                Node bigChild = ofNode(entry, false);
                bestSplit(bigChild, slot, entry);
                split(bigChild, entry, false, slot);
                replaceWithActual(bigChild);
            } else if (useStatPair()) {
                // Keep slots aligned with the two-partition layout even without a big partition.
                slot++;
            }
            slot++;
        }
    }

    /**
     * Finalizes a leaf or queues a non-leaf for further splitting.
     *
     * @param node the child just evaluated by the subclass {@code bestSplit}
     * @param nodeInfoPair the parent entry from the current round
     * @param z {@code true} to use the parent's {@code small} partition, {@code false} for {@code big}
     * @param i slot of {@code node}'s statistics in {@link #minusHist}
     */
    public final void split(Node node, NodeInfoPair nodeInfoPair, boolean z, int i) {
        if (node.isLeaf()) {
            node.makeLeafProb();
            return;
        }
        NodeInfoPair.Partition rows = z ? nodeInfoPair.small : nodeInfoPair.big;
        NodeInfoPair childEntry = ofNodeInfoPair(node, i, nodeInfoPair);
        split(node, rows, childEntry);
        addNodeInfoPair(childEntry);
    }

    /** Replaces a continuous split's bin index with the actual feature value from the quantile model. */
    public final void replaceWithActual(Node node) {
        if (node.isLeaf() || node.getCategoricalSplit() != null) {
            return;
        }
        node.setContinuousSplit(this.quantileDiscretizerModel.getFeatureValue(this.featureMetas[node.getFeatureIndex()].getName(), (int) node.getContinuousSplit()));
    }

    /**
     * Tests whether a row goes to the {@code small} (index 0) side of {@code node}'s split.
     *
     * @param i offset of the split feature's column in {@link #features}
     * @param i2 position in {@link #partitions} of the row to test
     * @param node the split node
     * @return for continuous splits, value {@code <=} split; for categorical splits, mapping {@code == 0}
     */
    public final boolean left(int i, int i2, Node node) {
        int value = this.features[i + this.partitions[i2]];
        return node.getCategoricalSplit() == null ? ((double) value) <= node.getContinuousSplit() : node.getCategoricalSplit()[value] == 0;
    }

    /**
     * Reorders {@code partition}'s rows in {@link #partitions} so that rows satisfying
     * {@link #left} come first, and records the two resulting ranges as
     * {@code nodeInfoPair.small} and {@code nodeInfoPair.big}. Does nothing for leaves.
     */
    public final void split(Node node, NodeInfoPair.Partition partition, NodeInfoPair nodeInfoPair) {
        if (node.isLeaf()) {
            return;
        }
        int columnOffset = node.getFeatureIndex() * this.numLocalRow;
        int lo = partition.start;
        int hi = partition.end - 1;
        while (lo <= hi) {
            while (lo <= hi && left(columnOffset, lo, node)) {
                lo++;
            }
            while (lo <= hi && !left(columnOffset, hi, node)) {
                hi--;
            }
            if (lo < hi) {
                int tmp = this.partitions[lo];
                this.partitions[lo] = this.partitions[hi];
                this.partitions[hi] = tmp;
            }
        }
        nodeInfoPair.small = new NodeInfoPair.Partition();
        nodeInfoPair.big = new NodeInfoPair.Partition();
        nodeInfoPair.small.start = partition.start;
        nodeInfoPair.small.end = lo;
        nodeInfoPair.big.start = lo;
        nodeInfoPair.big.end = partition.end;
    }

    /** Returns {@code true} once no node is left to split. */
    public final boolean terminationCriterion() {
        return this.queue.isEmpty();
    }

    /** Returns the root of every tree, in tree order (the live list, not a copy). */
    public final List<Node> getRoots() {
        return this.roots;
    }

    /** Stat-pair mode: per-node feature bagging without root-level bag inheritance. */
    private boolean useStatPair() {
        return baggingFeatureCount() != this.nFeatureCol && !this.rootBagging;
    }

    /**
     * Number of features per bag: {@code FEATURE_SUBSAMPLING_RATIO * nFeatureCol} truncated,
     * clamped to {@code [1, nFeatureCol]}.
     */
    protected int baggingFeatureCount() {
        return Math.max(1, Math.min((int) (((Double) this.params.get(RandomForestTrainParams.FEATURE_SUBSAMPLING_RATIO)).doubleValue() * this.nFeatureCol), this.nFeatureCol));
    }
}
