package com.alibaba.alink.operator.common.similarity.modeldata;

import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.common.distance.EuclideanDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceData;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.operator.common.similarity.KDTree;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/similarity/modeldata/KDTreeModelData.class */
/**
 * Nearest-neighbor model data backed by a list of {@link KDTree}s, scored with Euclidean distance.
 *
 * <p>Each tree is addressed by its position in the list. Candidate neighbors are produced as
 * {@code (distance, field 0 of the matched row)} tuples.
 */
public class KDTreeModelData extends NearestNeighborModelData {
    private static final long serialVersionUID = 6293716822361507135L;

    /** Metric used both to prepare query samples and to re-score range-search hits. */
    private static final EuclideanDistance distance = new EuclideanDistance();

    /** The searchable trees; {@code computeDistiance} selects one by index. */
    private final List<KDTree> treeList;

    /**
     * Creates model data over the given trees.
     *
     * @param list the KD-trees to search; stored by reference, not copied
     */
    public KDTreeModelData(List<KDTree> list) {
        this.treeList = list;
        // Orders candidates by descending distance (ascending on the negated value).
        // NOTE(review): presumably this keeps the farthest candidate at the head of the parent's
        // priority queue so it is evicted first — confirm against NearestNeighborModelData.
        this.comparator = Comparator.comparingDouble(candidate -> -((Double) candidate.f0));
    }

    /**
     * Returns how many KD-trees this model holds.
     *
     * @return the size of the tree list
     */
    @Override
    protected Integer getLength() {
        return treeList.size();
    }

    /**
     * Converts a raw sample into the vector data used by the KD-tree queries.
     *
     * @param obj any value accepted by {@code VectorUtil.getVector}
     * @return the prepared vector data; {@code computeDistiance} casts it back to
     *         {@link FastDistanceVectorData}
     */
    @Override
    protected Object prepareSample(Object obj) {
        return distance.prepareVectorData(Tuple2.of(VectorUtil.getVector(obj), null));
    }

    /**
     * Finds neighbor candidates for a prepared sample in one KD-tree.
     *
     * <p>When {@code num2} is non-null, the tree's top-N query is used. If a radius value is also
     * supplied, each candidate is then filtered against {@code tuple2} via
     * {@link #getQueueComparator()}. When {@code num2} is null, a range search with radius
     * {@code tuple2.f0} is run, and every hit's distance is recomputed with the Euclidean metric.
     *
     * <p>The spelling of the method name is fixed by the overridden parent method.
     *
     * @param obj    a sample produced by {@link #prepareSample(Object)}; must be a
     *               {@link FastDistanceVectorData}
     * @param num    index of the tree to search
     * @param num2   number of nearest neighbors to query, or {@code null} for a radius search
     * @param tuple2 the radius in {@code f0} ({@code f1} is not used here); it may be {@code null},
     *               or hold a {@code null} radius, only when {@code num2} is non-null
     * @return the {@code (distance, field 0 of the matched row)} candidates
     * @throws NullPointerException if {@code num2} is {@code null} and no radius value is given
     */
    @Override
    protected ArrayList<Tuple2<Double, Object>> computeDistiance(Object obj, Integer num, Integer num2,
                                                                Tuple2<Double, Object> tuple2) {
        KDTree tree = treeList.get(num);
        FastDistanceVectorData sample = (FastDistanceVectorData) obj;
        ArrayList<Tuple2<Double, Object>> result = new ArrayList<>();

        if (num2 != null) {
            boolean filterByRadius = tuple2 != null && tuple2.f0 != null;
            for (Tuple2<Double, Row> hit : tree.getTopN(num2, sample)) {
                Tuple2<Double, Object> candidate = Tuple2.of(hit.f0, hit.f1.getField(0));
                // Assuming getQueueComparator() returns the comparator set in the constructor,
                // this keeps candidates whose distance does not exceed the radius — TODO confirm.
                if (!filterByRadius || getQueueComparator().compare(tuple2, candidate) <= 0) {
                    result.add(candidate);
                }
            }
            return result;
        }

        // A range search needs a concrete radius. Fail with a clear message here, rather than
        // with an anonymous NPE when a null radius is auto-unboxed.
        Objects.requireNonNull(tuple2, "radius must be provided when topN is null");
        Double radius = Objects.requireNonNull(
            tuple2.f0, "radius value must not be null when topN is null");

        for (FastDistanceVectorData hit : tree.rangeSearch(radius, sample)) {
            // The FastDistanceData casts select the calc overload that returns a distance matrix.
            Double dist = distance.calc((FastDistanceData) hit, (FastDistanceData) sample).get(0, 0);
            result.add(Tuple2.of(dist, hit.getRows()[0].getField(0)));
        }
        return result;
    }
}
