package com.alibaba.alink.operator.common.clustering;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.clustering.BisectingKMeansTrainBatchOp;
import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelData;
import com.alibaba.alink.operator.common.clustering.kmeans.KMeansUtil;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.distance.EuclideanDistance;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/BisectingKMeansModelMapper.class */
/**
 * Model mapper that assigns each input vector to a leaf cluster of a bisecting k-means tree.
 *
 * <p>The model's cluster summaries are keyed by tree-node id. The root has id {@code 1}, and child
 * ids come from {@code BisectingKMeansTrainBatchOp.leftChildIndex/rightChildIndex}. Prediction
 * descends from the root to a leaf. The prediction detail is the probability vector returned by
 * {@code KMeansUtil.getProbArrayFromDistanceArray} for the tree-edge distances between the
 * predicted leaf and every leaf.
 *
 * <p>NOTE(review): {@link #nodeDistanceInTree(long, long)} treats {@code id / 2} as a node's
 * parent (heap-style numbering). This presumably matches the trainer's child-index functions —
 * confirm.
 */
public class BisectingKMeansModelMapper extends RichModelMapper {
    private static final long serialVersionUID = 1293356859097519385L;

    /** Tree-node id of the root cluster in the model's summary map. */
    private static final long ROOT_NODE_ID = 1L;

    private BisectingKMeansModelData modelData;
    private Tree tree;
    private int vectorColIdx;

    /** Binary tree of cluster centers rebuilt from the model's cluster summaries. */
    public static class Tree {
        TreeNode root;
        /** Tree-node ids of the leaves, indexed by the cluster id assigned to each leaf. */
        List<Long> treeNodeIds;

        /**
         * Builds the tree breadth-first, starting from the root summary.
         *
         * <p>After building, it precomputes the separating hyperplanes and numbers the leaves.
         *
         * @param map cluster summaries keyed by tree-node id
         * @throws AkIllegalDataException if {@code map} has no summary for the root id
         */
        public Tree(Map<Long, BisectingKMeansModelData.ClusterSummary> map) {
            BisectingKMeansModelData.ClusterSummary rootSummary = map.get(ROOT_NODE_ID);
            if (rootSummary == null) {
                throw new AkIllegalDataException(
                    "Bisecting KMeans model has no root cluster (tree node id " + ROOT_NODE_ID + ").");
            }
            this.root = new TreeNode(ROOT_NODE_ID, rootSummary.center);

            ArrayDeque<TreeNode> pending = new ArrayDeque<>();
            pending.add(this.root);
            while (!pending.isEmpty()) {
                TreeNode node = pending.poll();
                node.leftChild = attachChild(
                    map, BisectingKMeansTrainBatchOp.leftChildIndex(node.treeNodeId), pending);
                node.rightChild = attachChild(
                    map, BisectingKMeansTrainBatchOp.rightChildIndex(node.treeNodeId), pending);
            }
            this.root.constructMiddlePlane();
            assignClusterId();
        }

        /**
         * Creates the node for {@code childId} and queues it for expansion.
         *
         * @return the new node, or {@code null} if the model has no summary for {@code childId}
         */
        private static TreeNode attachChild(Map<Long, BisectingKMeansModelData.ClusterSummary> map,
                                            long childId, ArrayDeque<TreeNode> pending) {
            BisectingKMeansModelData.ClusterSummary summary = map.get(childId);
            if (summary == null) {
                return null;
            }
            TreeNode child = new TreeNode(childId, summary.center);
            pending.add(child);
            return child;
        }

        /**
         * Numbers the leaves 0, 1, 2, ... in breadth-first order, left before right.
         *
         * <p>The leaves' tree-node ids are recorded in the same order, so
         * {@code treeNodeIds.get(clusterId)} is the node id of that cluster.
         */
        private void assignClusterId() {
            ArrayDeque<TreeNode> pending = new ArrayDeque<>();
            pending.add(this.root);
            long nextClusterId = 0;
            this.treeNodeIds = new ArrayList<>();
            while (!pending.isEmpty()) {
                TreeNode node = pending.poll();
                if (node.isLeaf()) {
                    node.clusterId = nextClusterId++;
                    this.treeNodeIds.add(node.treeNodeId);
                } else {
                    if (node.leftChild != null) {
                        pending.add(node.leftChild);
                    }
                    if (node.rightChild != null) {
                        pending.add(node.rightChild);
                    }
                }
            }
        }

        /**
         * Routes {@code vector} from the root to a leaf.
         *
         * @return {@code (clusterId, treeNodeId)} of the leaf reached
         */
        public Tuple2<Long, Long> predict(Vector vector, ContinuousDistance continuousDistance) {
            return this.root.predict(vector, continuousDistance);
        }

        /** Returns the leaf tree-node ids, indexed by cluster id. */
        public List<Long> getTreeNodeIds() {
            return this.treeNodeIds;
        }
    }

    /** One node of the cluster tree: a center plus up to two children. */
    public static class TreeNode {
        long treeNodeId;
        DenseVector center;
        /**
         * Perpendicular bisector between the two child centers.
         *
         * <p>{@code f0} is the normal {@code right - left}. {@code f1} is the offset
         * {@code ((left + right) / 2) . normal}. The field is null unless both children exist.
         */
        Tuple2<DenseVector, Double> middlePlane;
        TreeNode leftChild = null;
        TreeNode rightChild = null;
        /** Cluster id for leaves; -1 for inner nodes. */
        long clusterId = -1;

        public TreeNode(long treeNodeId, DenseVector center) {
            this.treeNodeId = treeNodeId;
            this.center = center;
        }

        public boolean isLeaf() {
            return this.leftChild == null && this.rightChild == null;
        }

        /**
         * Recursively computes {@link #middlePlane} for every node that has two children.
         *
         * <p>For Euclidean distance, a point {@code x} is closer to the left center exactly when
         * {@code x . (right - left) < ((left + right) / 2) . (right - left)}. Nodes with a single
         * child get no plane; {@link #selectChild} routes straight to that child.
         */
        void constructMiddlePlane() {
            if (this.leftChild != null && this.rightChild != null) {
                DenseVector normal = this.rightChild.center.mo136clone();
                DenseVector midPoint = this.rightChild.center.mo136clone();
                // midPoint = (right + left) / 2 ; normal = right - left
                BLAS.axpy(1.0d, this.leftChild.center, midPoint);
                BLAS.axpy(-1.0d, this.leftChild.center, normal);
                BLAS.scal(0.5d, midPoint);
                this.middlePlane = Tuple2.of(normal, BLAS.dot(midPoint, normal));
            }
            if (this.leftChild != null) {
                this.leftChild.constructMiddlePlane();
            }
            if (this.rightChild != null) {
                this.rightChild.constructMiddlePlane();
            }
        }

        /**
         * Descends from this node to a leaf, choosing the nearer child at each step.
         *
         * @return {@code (clusterId, treeNodeId)} of the leaf reached
         */
        public Tuple2<Long, Long> predict(Vector vector, ContinuousDistance continuousDistance) {
            TreeNode node = this;
            while (!node.isLeaf()) {
                node = node.selectChild(vector, continuousDistance);
            }
            return Tuple2.of(node.clusterId, node.treeNodeId);
        }

        /**
         * Picks the child {@code vector} should descend into.
         *
         * <p>If only one child exists, that child is returned. For {@link EuclideanDistance} the
         * precomputed middle plane decides, and ties go right. Other distances delegate to
         * {@code BisectingKMeansTrainBatchOp.getClosestNode}.
         */
        private TreeNode selectChild(Vector vector, ContinuousDistance continuousDistance) {
            if (this.leftChild == null) {
                return this.rightChild;
            }
            if (this.rightChild == null) {
                return this.leftChild;
            }
            if (continuousDistance instanceof EuclideanDistance) {
                Vector normal = this.middlePlane.f0;
                return MatVecOp.dot(vector, normal) < this.middlePlane.f1 ? this.leftChild : this.rightChild;
            }
            return BisectingKMeansTrainBatchOp.getClosestNode(
                0L, this.leftChild.center, 1L, this.rightChild.center, vector, continuousDistance) == 0
                ? this.leftChild : this.rightChild;
        }
    }

    public BisectingKMeansModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
    }

    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected TypeInformation<?> initPredResultColType(TableSchema tableSchema) {
        return Types.LONG;
    }

    /**
     * Converts the tree distance from the predicted node to each leaf into a probability array.
     *
     * @param treeNodeId tree-node id of the predicted leaf
     * @param leafNodeIds leaf tree-node ids, indexed by cluster id
     * @return the result of {@code KMeansUtil.getProbArrayFromDistanceArray} over those distances
     */
    private double[] computeProbability(long treeNodeId, List<Long> leafNodeIds) {
        double[] distances = new double[leafNodeIds.size()];
        for (int i = 0; i < distances.length; i++) {
            distances[i] = nodeDistanceInTree(treeNodeId, leafNodeIds.get(i));
        }
        return KMeansUtil.getProbArrayFromDistanceArray(distances);
    }

    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Object predictResult(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        return predictResultDetail(slicedSelectedSample).f0;
    }

    /**
     * Predicts the cluster id of the selected vector.
     *
     * <p>The detail string is the serialized probability vector over all clusters.
     *
     * @throws AkIllegalDataException if the vector size differs from the training vector size
     */
    @Override // com.alibaba.alink.common.mapper.RichModelMapper
    protected Tuple2<Object, String> predictResultDetail(Mapper.SlicedSelectedSample slicedSelectedSample) throws Exception {
        Vector vector = VectorUtil.getVector(slicedSelectedSample.get(this.vectorColIdx));
        if (vector.size() != this.modelData.vectorSize) {
            throw new AkIllegalDataException("Dim of predict data not equal to vectorSize of training data: " + this.modelData.vectorSize);
        }
        Tuple2<Long, Long> clusterAndNode = this.tree.predict(vector, this.modelData.distanceType.getFastDistance());
        double[] probabilities = computeProbability(clusterAndNode.f1, this.tree.getTreeNodeIds());
        return Tuple2.of(clusterAndNode.f0, VectorUtil.serialize(new DenseVector(probabilities)));
    }

    /**
     * Returns the depth of a node, i.e. {@code floor(log2(treeNodeId))}.
     *
     * <p>The root (id 1) has depth 0. Ids {@code <= 1} map to 0.
     */
    private int level(long treeNodeId) {
        return treeNodeId <= 1 ? 0 : 63 - Long.numberOfLeadingZeros(treeNodeId);
    }

    /**
     * Counts the parent hops between two nodes, taking {@code id / 2} as a node's parent.
     *
     * <p>The deeper node first climbs to the other's depth. Then both climb together until they
     * meet at their common ancestor.
     */
    private double nodeDistanceInTree(long nodeA, long nodeB) {
        int levelA = level(nodeA);
        int levelB = level(nodeB);
        int hops = 0;
        // Each halving of an id >= 2 lowers its level by exactly one, so the counters stay in sync.
        for (; levelA > levelB; levelA--) {
            nodeA /= 2;
            hops++;
        }
        for (; levelB > levelA; levelB--) {
            nodeB /= 2;
            hops++;
        }
        while (nodeA != nodeB) {
            nodeA /= 2;
            nodeB /= 2;
            hops += 2;
        }
        return hops;
    }

    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        this.modelData = new BisectingKMeansModelDataConverter().load(list);
        this.vectorColIdx = TableUtil.findColIndexWithAssert(super.getDataSchema().getFieldNames(), this.modelData.vectorColName);
        this.tree = new Tree(this.modelData.summaries);
    }
}
