package com.alibaba.alink.operator.common.similarity;

import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.distance.EuclideanDistance;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceData;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Deque;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Stack;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/similarity/KDTree.class */
/**
 * A k-d tree over {@link FastDistanceVectorData} samples that answers radius queries
 * ({@link #rangeSearch}) and top-N nearest-neighbour queries ({@link #getTopN}).
 *
 * <p>{@link #buildTree()} reorders {@code samples} in place (it is the same array handed to the
 * constructor) so that every {@link TreeNode} owns a contiguous index range
 * {@code [startIndex, endIndex)}. Until {@link #buildTree()} has run, queries see an empty tree.
 * Bounding-box pruning is only performed when {@code distance} is an {@link EuclideanDistance};
 * for any other distance every node is visited.
 */
public class KDTree implements Serializable {
    private static final long serialVersionUID = 2164916702239332666L;

    /** Tolerance used when comparing a query bound against a bounding-box distance. */
    private static final double EPS = 1.0E-12d;

    private int vectorSize;
    FastDistanceVectorData[] samples;
    FastDistance distance;
    /** Target number of samples per leaf; used to derive the tree depth in {@link #buildTree()}. */
    private final int LEAF_SIZE = 40;
    private int nLevel;
    private int nodeNum;
    private TreeNode root;

    /**
     * A tree node owning the samples in {@code [startIndex, endIndex)}.
     *
     * <p>For an internal node, {@code nodeIndex} is the position of the pivot sample chosen by
     * {@code KDTree.split}; the left child owns {@code [startIndex, nodeIndex)} and the right child
     * owns {@code [nodeIndex + 1, endIndex)}. For a leaf, {@code nodeIndex == startIndex}.
     * {@code downThre} / {@code upThre} hold the per-dimension minimum / maximum of the owned
     * samples (filled by {@code KDTree.findBounds}).
     */
    public static class TreeNode implements Serializable {
        private static final long serialVersionUID = 2059317701009040907L;
        public int nodeIndex;
        public int startIndex;
        public int endIndex;
        public int splitDim;
        public TreeNode left;
        public TreeNode right;
        public double[] downThre;
        public double[] upThre;
        boolean isLeaf;

        /**
         * Creates a non-leaf node over {@code [startIndex, endIndex)} whose pivot index initially
         * equals {@code startIndex}.
         */
        public TreeNode(int startIndex, int endIndex) {
            this.nodeIndex = startIndex;
            this.startIndex = startIndex;
            this.endIndex = endIndex;
            this.isLeaf = false;
        }

        /** Creates an empty non-leaf node over {@code [0, 0)}. */
        public TreeNode() {
            this(0, 0);
        }
    }

    public TreeNode getRoot() {
        return this.root;
    }

    public void setRoot(TreeNode treeNode) {
        this.root = treeNode;
    }

    /**
     * @param samples the samples to index; this array is reordered in place by {@link #buildTree()}
     * @param vectorSize number of dimensions read from each sample vector
     * @param distance distance used for queries; pruning is enabled only for {@link EuclideanDistance}
     */
    public KDTree(FastDistanceVectorData[] samples, int vectorSize, FastDistance distance) {
        this.vectorSize = vectorSize;
        this.samples = samples;
        this.distance = distance;
    }

    /**
     * Builds the tree over {@code samples}, permuting the array in place.
     *
     * <p>The depth grows with log2 of the number of {@code LEAF_SIZE}-sized chunks; nodes are
     * numbered heap-style and {@code nodeNum = 2^nLevel - 1} bounds where leaves begin.
     */
    public void buildTree() {
        // Integer division on purpose; Math.log(0) = -Infinity is clamped by Math.max below.
        int chunks = (samples.length - 1) / LEAF_SIZE;
        double depth = Math.log(chunks) / Math.log(2.0d);
        // NOTE(review): Criteria.INVALID_GAIN is a decision-tree constant used here only as the lower
        // clamp; it looks like a decompiler substitution for a literal (likely 0.0) -- confirm.
        nLevel = 1 + (int) Math.max(Criteria.INVALID_GAIN, depth);
        nodeNum = (int) Math.pow(2.0d, nLevel) - 1;
        root = recursiveBuild(0, 0, samples.length);
    }

    /**
     * Returns every sample whose distance to {@code query} is at most {@code radius}.
     *
     * <p>With a Euclidean distance, a subtree whose bounding box lies entirely inside the radius is
     * added without per-sample distance computations, and a subtree whose box is farther than the
     * radius (by more than {@code EPS}) is skipped. Results are in traversal order.
     *
     * @param radius inclusive distance threshold
     * @param query the query sample
     * @return a new mutable list of matching samples; empty when the tree is empty
     */
    public List<FastDistanceVectorData> rangeSearch(double radius, FastDistanceVectorData query) {
        List<FastDistanceVectorData> result = new ArrayList<>();
        if (root == null) {
            return result;
        }
        boolean euclidean = distance instanceof EuclideanDistance;
        Deque<TreeNode> pending = new ArrayDeque<>();
        pending.push(root);
        while (!pending.isEmpty()) {
            TreeNode node = pending.pop();
            if (euclidean) {
                Tuple2<Double, Double> bounds = minMaxDistance(node, query.getVector());
                if (radius >= bounds.f1) {
                    // The whole bounding box is within the radius: take the subtree wholesale.
                    result.addAll(Arrays.asList(samples).subList(node.startIndex, node.endIndex));
                    continue;
                }
                if (exceedsBound(bounds.f0, radius)) {
                    continue;
                }
            }
            if (node.isLeaf) {
                for (int j = node.startIndex; j < node.endIndex; j++) {
                    if (distanceTo(query, samples[j]) <= radius) {
                        // Add the matching sample itself (previously the leaf's first sample was added).
                        result.add(samples[j]);
                    }
                }
            } else {
                if (distanceTo(query, samples[node.nodeIndex]) <= radius) {
                    result.add(samples[node.nodeIndex]);
                }
                // Push left then right, so the right child is explored first (same as before).
                if (node.left != null) {
                    pending.push(node.left);
                }
                if (node.right != null) {
                    pending.push(node.right);
                }
            }
        }
        return result;
    }

    /**
     * Returns the {@code topN} samples closest to {@code query}, sorted by ascending distance.
     *
     * <p>Each entry pairs the distance with the first row attached to the sample
     * ({@code getRows()[0]}). Fewer than {@code topN} entries are returned when the tree holds
     * fewer samples. Among candidates at exactly the same distance the choice is arbitrary.
     *
     * @param topN number of neighbours wanted; 0 yields an empty array
     * @param query the query sample
     * @return neighbours in ascending distance order; empty when the tree is empty
     * @throws IllegalArgumentException if {@code topN} is negative
     */
    public Tuple2<Double, Row>[] getTopN(int topN, FastDistanceVectorData query) {
        if (topN < 0) {
            throw new IllegalArgumentException("topN must be non-negative, got " + topN);
        }
        if (topN == 0 || root == null) {
            return newResultArray(0);
        }
        boolean euclidean = distance instanceof EuclideanDistance;
        // Max-heap on distance: the head is always the worst of the current candidates.
        PriorityQueue<Tuple2<Double, Row>> heap =
            new PriorityQueue<>(topN, (a, b) -> Double.compare(b.f0, a.f0));
        // Each entry carries a node and the precomputed distance to its pivot sample.
        Deque<Tuple2<TreeNode, Double>> pending = new ArrayDeque<>();
        pending.push(Tuple2.of(root, distanceTo(query, samples[root.nodeIndex])));
        while (!pending.isEmpty()) {
            Tuple2<TreeNode, Double> entry = pending.pop();
            TreeNode node = entry.f0;
            // Once the heap is full, a box farther than the current worst cannot improve it; its
            // children's boxes are contained in it, so the whole subtree can be skipped.
            if (euclidean && heap.size() >= topN
                && exceedsBound(minMaxDistance(node, query.getVector()).f0, heap.peek().f0)) {
                continue;
            }
            // For a leaf the pivot is samples[startIndex], so the loop below starts one past it.
            offer(heap, topN, entry.f1, samples[node.nodeIndex]);
            if (node.isLeaf) {
                for (int j = node.startIndex + 1; j < node.endIndex; j++) {
                    offer(heap, topN, distanceTo(query, samples[j]), samples[j]);
                }
            }
            pushChildren(pending, node, query);
        }
        Tuple2<Double, Row>[] result = newResultArray(heap.size());
        // Polling a max-heap yields descending distances, so fill from the back.
        for (int k = result.length - 1; k >= 0; k--) {
            result[k] = heap.poll();
        }
        return result;
    }

    /** Adds a candidate to a bounded max-heap, evicting the current worst if the candidate beats it. */
    private static void offer(
        PriorityQueue<Tuple2<Double, Row>> heap, int capacity, double dist, FastDistanceVectorData candidate) {
        if (heap.size() < capacity) {
            heap.add(Tuple2.of(dist, candidate.getRows()[0]));
        } else if (dist < heap.peek().f0) {
            heap.poll();
            heap.add(Tuple2.of(dist, candidate.getRows()[0]));
        }
    }

    /** Pushes the non-null children of {@code node}, farther one first so the nearer is popped next. */
    private void pushChildren(
        Deque<Tuple2<TreeNode, Double>> pending, TreeNode node, FastDistanceVectorData query) {
        double leftDist = node.left == null ? Double.MAX_VALUE : distanceTo(query, samples[node.left.nodeIndex]);
        double rightDist = node.right == null ? Double.MAX_VALUE : distanceTo(query, samples[node.right.nodeIndex]);
        if (leftDist >= rightDist) {
            pushIfPresent(pending, node.left, leftDist);
            pushIfPresent(pending, node.right, rightDist);
        } else {
            pushIfPresent(pending, node.right, rightDist);
            pushIfPresent(pending, node.left, leftDist);
        }
    }

    /** Pushes {@code (node, dist)} unless {@code node} is null ({@link ArrayDeque} rejects nulls). */
    private static void pushIfPresent(Deque<Tuple2<TreeNode, Double>> pending, TreeNode node, double dist) {
        if (node != null) {
            pending.push(Tuple2.of(node, dist));
        }
    }

    /**
     * True when {@code minDist} exceeds {@code bound} by more than {@code EPS}; also true when
     * either value is NaN, matching the previous pruning test.
     */
    private static boolean exceedsBound(double minDist, double bound) {
        return !(Math.abs(bound - minDist) <= EPS || bound >= minDist);
    }

    /** Builds a generically-typed result array (generic array creation needs an unchecked cast). */
    @SuppressWarnings("unchecked")
    private static Tuple2<Double, Row>[] newResultArray(int size) {
        return (Tuple2<Double, Row>[]) new Tuple2[size];
    }

    /** Distance between two samples, read from entry (0, 0) of {@code distance.calc}'s result. */
    private double distanceTo(FastDistanceVectorData query, FastDistanceVectorData other) {
        // The casts pin the call to the calc overload taking FastDistanceData arguments.
        return distance.calc((FastDistanceData) query, (FastDistanceData) other).get(0, 0);
    }

    /**
     * Minimum and maximum Euclidean distance from {@code vector} to the node's axis-aligned
     * bounding box. Returns {@code (0, 0)} when {@code distance} is not an {@link EuclideanDistance}.
     */
    private Tuple2<Double, Double> minMaxDistance(TreeNode treeNode, Vector vector) {
        double minSq = 0.0d;
        double maxSq = 0.0d;
        if (distance instanceof EuclideanDistance) {
            for (int dim = 0; dim < vectorSize; dim++) {
                double v = vector.get(dim);
                double low = treeNode.downThre[dim];
                double high = treeNode.upThre[dim];
                double toLow = (v - low) * (v - low);
                double toHigh = (v - high) * (v - high);
                if (v < low) {
                    minSq += toLow;
                } else if (v > high) {
                    minSq += toHigh;
                }
                // The farthest box face along this axis is always the larger of the two.
                maxSq += Math.max(toLow, toHigh);
            }
        }
        return Tuple2.of(Math.sqrt(minSq), Math.sqrt(maxSq));
    }

    /**
     * Builds the subtree over {@code [start, end)} using {@code start} as the heap position too.
     *
     * <p>NOTE(review): the position argument equals {@code start}, which matches the root's
     * numbering only when {@code start == 0} -- confirm the intent with callers.
     */
    TreeNode recursiveBuild(int start, int end) {
        return recursiveBuild(start, start, end);
    }

    /**
     * Recursively builds the subtree over {@code [start, end)}.
     *
     * @param position heap-style node number (root 0, children {@code 2p+1} and {@code 2p+2})
     * @return the subtree root, or {@code null} when the range is empty
     */
    TreeNode recursiveBuild(int position, int start, int end) {
        if (start >= end) {
            return null;
        }
        TreeNode node = new TreeNode(start, end);
        findBounds(node);
        // A node becomes a leaf once its first child's number would exceed nodeNum.
        if (2 * position + 1 > nodeNum) {
            node.isLeaf = true;
            return node;
        }
        node.splitDim = pickSplitDim(start, end);
        node.nodeIndex = split(start, end, node.splitDim);
        node.left = recursiveBuild(2 * position + 1, node.startIndex, node.nodeIndex);
        node.right = recursiveBuild(2 * position + 2, node.nodeIndex + 1, node.endIndex);
        return node;
    }

    /** Stores the per-dimension minimum ({@code downThre}) and maximum ({@code upThre}) of the node's samples. */
    void findBounds(TreeNode treeNode) {
        double[] low = new double[vectorSize];
        double[] high = new double[vectorSize];
        Arrays.fill(low, Double.POSITIVE_INFINITY);
        Arrays.fill(high, Double.NEGATIVE_INFINITY);
        for (int j = treeNode.startIndex; j < treeNode.endIndex; j++) {
            Vector vec = samples[j].getVector();
            for (int dim = 0; dim < vectorSize; dim++) {
                double value = vec.get(dim);
                low[dim] = Math.min(low[dim], value);
                high[dim] = Math.max(high[dim], value);
            }
        }
        treeNode.upThre = high;
        treeNode.downThre = low;
    }

    /**
     * Returns the dimension with the widest value range over {@code samples[start, end)} (first
     * one wins on ties), or -1 when no dimension qualifies (e.g. {@code vectorSize == 0}).
     */
    int pickSplitDim(int start, int end) {
        int bestDim = -1;
        double bestSpread = Double.NEGATIVE_INFINITY;
        for (int dim = 0; dim < vectorSize; dim++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (int j = start; j < end; j++) {
                double value = samples[j].getVector().get(dim);
                min = Math.min(min, value);
                max = Math.max(max, value);
            }
            if (max - min > bestSpread) {
                bestDim = dim;
                bestSpread = max - min;
            }
        }
        return bestDim;
    }

    /**
     * Quickselect on dimension {@code dim}: reorders {@code samples[start, end)} in place so that
     * the element at the middle index {@code start + (end - 1 - start) / 2} has no larger value
     * before it and no smaller value after it, and returns that index. Each round pivots on the
     * first element of the current sub-range.
     */
    int split(int start, int end, int dim) {
        int lo = start;
        int hi = end - 1;
        final int target = start + (hi - start) / 2;
        while (true) {
            FastDistanceVectorData pivot = samples[lo];
            double pivotValue = pivot.getVector().get(dim);
            int left = lo;
            int right = hi;
            // Hole-based partition: the pivot's slot is the initial hole.
            while (left < right) {
                while (left < right && samples[right].getVector().get(dim) >= pivotValue) {
                    right--;
                }
                if (left < right) {
                    samples[left++] = samples[right];
                }
                while (left < right && samples[left].getVector().get(dim) <= pivotValue) {
                    left++;
                }
                if (left < right) {
                    samples[right--] = samples[left];
                }
            }
            samples[left] = pivot;
            if (left == target) {
                return target;
            }
            if (left < target) {
                lo = left + 1;
            } else {
                hi = left - 1;
            }
        }
    }
}
