package com.alibaba.alink.operator.common.tree.seriestree;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.operator.common.tree.FeatureMeta;
import com.alibaba.alink.operator.common.tree.Node;
import com.alibaba.alink.params.classification.RandomForestTrainParams;
import com.alibaba.alink.params.shared.tree.HasSeed;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/seriestree/DecisionTree.class */
public class DecisionTree {
    /**
     * Grows a single decision tree over a {@link DenseData} in breadth-first order.
     *
     * <p>Pending work is a FIFO queue of (node, feature splitters) pairs. For every node taken from
     * the queue:
     * <ol>
     *   <li>a subset of the splitters is selected (feature bagging, driven by a {@link Random} seeded
     *       from {@link HasSeed#SEED});</li>
     *   <li>each selected splitter computes its best split as a task on the supplied
     *       {@link ExecutorService}, at most {@code NUM_THREADS} at a time;</li>
     *   <li>the splitter returning the largest value fills the node;</li>
     *   <li>that splitter then either splits the node and enqueues its children, or the node becomes
     *       a leaf.</li>
     * </ol>
     *
     * <p>{@link #fit()} mutates instance state (queue, random generator, per-slot partitions, futures
     * list) without synchronization. Do not call it concurrently on the same instance.
     */
    private final DenseData data;
    private final Params params;

    /** Nodes still to be fitted, each paired with the splitters for that node; consumed FIFO. */
    private final Deque<Tuple2<Node, SequentialFeatureSplitter[]>> queue = new ArrayDeque<>();

    /** Source of randomness for feature bagging; seeded from {@link HasSeed#SEED}. */
    private final Random random;

    /** Maximum number of splitters submitted to the executor per batch. */
    private static final int NUM_THREADS = 4;

    /** One scratch partition per batch slot: the splitter running in slot k uses {@code threadLocals[k]}. */
    private final SequentialPartition[] threadLocals;

    private final ExecutorService executorService;

    /** Futures of the batch currently in flight; reused across batches. */
    private final ArrayList<Future<Double>> futures;

    /**
     * Creates a tree builder.
     *
     * @param denseData       the training data
     * @param params          training parameters. This class reads {@link HasSeed#SEED},
     *                        {@link RandomForestTrainParams#FEATURE_SUBSAMPLING_RATIO} and
     *                        {@link RandomForestTrainParams#NUM_SUBSET_FEATURES}; the whole object
     *                        is also handed to the feature splitters.
     * @param executorService executor that runs the per-feature split searches; must not be null
     */
    public DecisionTree(DenseData denseData, Params params, ExecutorService executorService) {
        this.data = denseData;
        this.params = params;
        this.random = new Random(params.get(HasSeed.SEED));
        AkPreconditions.checkNotNull(executorService);
        this.executorService = executorService;
        this.threadLocals = new SequentialPartition[NUM_THREADS];
        this.futures = new ArrayList<>(NUM_THREADS);
        for (int slot = 0; slot < NUM_THREADS; slot++) {
            this.threadLocals[slot] = new SequentialPartition(new ArrayList<>());
        }
    }

    /**
     * Fits the tree breadth-first, starting from a fresh root that covers every row of the data.
     *
     * <p>Fails fast, with a precondition error, if the data has no feature columns.
     *
     * @return the root node of the fitted tree
     */
    public Node fit() {
        Node root = new Node();
        init(root);
        while (!queue.isEmpty()) {
            Tuple2<Node, SequentialFeatureSplitter[]> entry = queue.poll();
            Node node = entry.f0;
            SequentialFeatureSplitter[] splitters = entry.f1;

            SequentialFeatureSplitter best = fitNode(bagging(splitters), queue.size());
            best.fillNode(node);
            if (best.canSplit()) {
                // The full splitter array (not only the bagged subset) is handed on to the children.
                split(splitters, node, best);
            } else {
                node.makeLeaf();
                node.makeLeafProb();
            }
        }
        return root;
    }

    /**
     * Picks the splitter with the largest {@code bestSplit} value for one node.
     *
     * @param splitters       candidate splitters (already bagged); assumed non-empty
     * @param numPendingNodes number of nodes still waiting in the queue. It is forwarded unchanged to
     *                        {@code bestSplit}. NOTE(review): how it is interpreted is up to
     *                        SequentialFeatureSplitter; confirm.
     * @return the winning splitter
     */
    private SequentialFeatureSplitter fitNode(SequentialFeatureSplitter[] splitters, int numPendingNodes) {
        return fitNodeMultiThread(splitters, numPendingNodes);
    }

    /**
     * Sequential equivalent of {@link #fitNodeMultiThread}. It is not called anywhere in this class
     * and is kept as a single-threaded fallback.
     */
    private SequentialFeatureSplitter fitNodeSingleThread(SequentialFeatureSplitter[] splitters, int numPendingNodes) {
        double bestGain = 0.0;
        SequentialFeatureSplitter best = null;
        for (SequentialFeatureSplitter candidate : splitters) {
            double gain = candidate.bestSplit(numPendingNodes);
            // The first candidate is always accepted, so the result is non-null for non-empty input.
            if (best == null || gain > bestGain) {
                bestGain = gain;
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Evaluates {@code bestSplit} for every splitter on the executor, in batches of at most
     * {@code NUM_THREADS}, and returns the splitter with the largest value. Ties keep the earlier
     * splitter.
     *
     * <p>Side effect: each evaluated splitter's partition is replaced (via {@code setPartition}) by the
     * scratch partition of the slot it ran in.
     *
     * @throws AkUnclassifiedErrorException if a task fails or the waiting thread is interrupted. In the
     *         interrupt case the thread's interrupt flag is restored; in both cases the futures still in
     *         flight are cancelled.
     */
    private SequentialFeatureSplitter fitNodeMultiThread(SequentialFeatureSplitter[] splitters, int numPendingNodes) {
        final int numSplitters = splitters.length;

        // Prepare only as many scratch partitions as will actually be used.
        // NOTE(review): all slots are seeded from splitters[0]'s partition, which presumes every
        // splitter of a node shares the same partition; confirm in SequentialFeatureSplitter.
        for (int slot = 0; slot < NUM_THREADS && slot < numSplitters; slot++) {
            splitters[0].getPartition().resetThreadLocal(threadLocals[slot]);
        }

        double bestGain = 0.0;
        SequentialFeatureSplitter best = null;
        for (int start = 0; start < numSplitters; start += NUM_THREADS) {
            final int end = Math.min(start + NUM_THREADS, numSplitters);

            futures.clear();
            for (int k = start; k < end; k++) {
                final SequentialFeatureSplitter splitter = splitters[k];
                // Each task in a batch gets its own scratch partition, so tasks do not share one.
                splitter.setPartition(threadLocals[k - start]);
                futures.add(executorService.submit(() -> splitter.bestSplit(numPendingNodes)));
            }

            // Wait for the whole batch before the scratch partitions are reused by the next one.
            for (int k = start; k < end; k++) {
                double gain = awaitGain(futures.get(k - start));
                if (best == null || gain > bestGain) {
                    bestGain = gain;
                    best = splitters[k];
                }
            }
        }
        return best;
    }

    /** Blocks for one split-search result, translating failures into an Alink exception. */
    private double awaitGain(Future<Double> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            cancelInFlightFutures();
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new AkUnclassifiedErrorException("Error. ", e);
        } catch (ExecutionException e) {
            cancelInFlightFutures();
            throw new AkUnclassifiedErrorException("Error. ", e);
        }
    }

    /** Cancels every future of the current batch; completed futures are unaffected. */
    private void cancelInFlightFutures() {
        for (Future<Double> pending : futures) {
            pending.cancel(true);
        }
    }

    /**
     * Splits {@code node} with {@code best}: creates one child per splitter group returned by
     * {@code best.split}, enqueues each child with its group, and attaches the children to
     * {@code node}.
     */
    private void split(SequentialFeatureSplitter[] splitters, Node node, SequentialFeatureSplitter best) {
        SequentialFeatureSplitter[][] childSplitters = (SequentialFeatureSplitter[][]) best.split(splitters);
        Node[] children = new Node[childSplitters.length];
        for (int k = 0; k < childSplitters.length; k++) {
            children[k] = new Node();
            queue.add(Tuple2.of(children[k], childSplitters[k]));
        }
        node.setNextNodes(children);
    }

    /**
     * Selects the feature subset considered for one node.
     *
     * <p>The subset size is {@code min(ratio * n, numSubsetFeatures)}, where {@code n} is the number of
     * splitters. It is raised to at least 1 and clamped to {@code n}. When fewer than all features are
     * kept, {@code splitters} is shuffled <em>in place</em> before the prefix is copied.
     *
     * @return a new array holding the selected splitters (empty only if {@code splitters} is empty)
     */
    private SequentialFeatureSplitter[] bagging(SequentialFeatureSplitter[] splitters) {
        final int numFeatures = splitters.length;
        int byRatio = (int) (params.get(RandomForestTrainParams.FEATURE_SUBSAMPLING_RATIO) * numFeatures);
        int bySubset = params.get(RandomForestTrainParams.NUM_SUBSET_FEATURES);
        // Clamp to numFeatures last. Otherwise an empty input would yield 1, and Arrays.copyOf would
        // pad the result with a null splitter.
        int numSelected = Math.min(numFeatures, Math.max(1, Math.min(byRatio, bySubset)));
        if (numSelected != numFeatures) {
            shuffle(splitters, random);
        }
        return Arrays.copyOf(splitters, numSelected);
    }

    /** In-place Fisher-Yates shuffle of {@code array} using {@code random}. */
    private static <T> void shuffle(T[] array, Random random) {
        for (int i = array.length - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            if (j != i) {
                T tmp = array[j];
                array[j] = array[i];
                array[i] = tmp;
            }
        }
    }

    /** Builds the root partition: one (row index, row weight) entry for each of the {@code data.m} rows. */
    private SequentialPartition initSequentialPartition() {
        ArrayList<Tuple2<Integer, Double>> rows = new ArrayList<>(data.m);
        for (int row = 0; row < data.m; row++) {
            rows.add(Tuple2.of(row, data.weights[row]));
        }
        return new SequentialPartition(rows);
    }

    /**
     * Creates one splitter per feature, all sharing {@code partition}. Categorical features get a
     * {@link CategoricalSplitter}; every other feature type gets a {@link ContinuousSplitter}.
     */
    private SequentialFeatureSplitter[] initSplitters(SequentialPartition partition) {
        final int numFeatures = data.featureMetas.length;
        SequentialFeatureSplitter[] splitters = new SequentialFeatureSplitter[numFeatures];
        for (int f = 0; f < numFeatures; f++) {
            boolean categorical = data.featureMetas[f].getType() == FeatureMeta.FeatureType.CATEGORICAL;
            splitters[f] = categorical
                ? new CategoricalSplitter(params, data, data.featureMetas[f], partition)
                : new ContinuousSplitter(params, data, data.featureMetas[f], partition);
        }
        return splitters;
    }

    /** Seeds the work queue with {@code root} and splitters over all rows and all features. */
    private void init(Node root) {
        // Without features there is no splitter to fill the root node; fail with a clear message
        // instead of a NullPointerException deep in the split search.
        AkPreconditions.checkArgument(
            data.featureMetas.length > 0, "DecisionTree requires at least one feature.");
        queue.push(Tuple2.of(root, initSplitters(initSequentialPartition())));
    }
}
