package com.alibaba.alink.operator.local.similarity;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.distance.CosineDistance;
import com.alibaba.alink.operator.common.distance.EuclideanDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData;
import com.alibaba.alink.operator.common.recommendation.KObjectUtil;
import com.alibaba.alink.operator.common.similarity.NearestNeighborsMapper;
import com.alibaba.alink.operator.common.similarity.dataConverter.VectorModelDataConverter;
import com.alibaba.alink.operator.common.similarity.modeldata.VectorModelData;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.local.AlinkLocalSession;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.source.AkSourceLocalOp;
import com.alibaba.alink.operator.local.utils.ModelMapLocalOp;
import com.alibaba.alink.operator.local.utils.TopK;
import com.alibaba.alink.params.shared.HasModelFilePath;
import com.alibaba.alink.params.shared.HasNumThreads;
import com.alibaba.alink.params.similarity.NearestNeighborPredictParams;
import com.alibaba.flink.ml.tf2.shaded.com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/**
 * Local operator that, for every row of the data input, looks up the nearest vectors stored in a
 * vector nearest-neighbor model and appends the serialized result as an extra string column.
 *
 * <p>The model is taken from the first input when two inputs are linked, otherwise it is read from
 * the configured model file path. At least one of {@code topN} and {@code radius} must be set.
 *
 * <p>When the model uses Euclidean or cosine distance and {@code topN} is set, distances are computed
 * in batches with a matrix multiplication. Every other case is delegated to
 * {@link SubVectorNearestNeighborPredictLocalOp}.
 */
@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})
@InputPorts(values = {@PortSpec(value = PortType.MODEL, suggestions = {VectorNearestNeighborTrainLocalOp.class}), @PortSpec(PortType.DATA)})
@NameCn("向量最近邻预测")
public class VectorNearestNeighborPredictLocalOp extends LocalOperator<VectorNearestNeighborPredictLocalOp> implements NearestNeighborPredictParams<VectorNearestNeighborPredictLocalOp>, HasModelFilePath<VectorNearestNeighborPredictLocalOp> {

    /** Number of query vectors that are scored together by one matrix multiplication. */
    private static final int QUERY_BATCH_SIZE = 256;

    /** Progress is printed whenever a thread's processed-row offset is a multiple of this value. */
    private static final int PROGRESS_REPORT_INTERVAL = 512;

    /**
     * Mapper-based predictor built on {@link NearestNeighborsMapper}; used as the fallback for every
     * request the batched matrix path of the outer operator does not handle.
     */
    @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})
    @InputPorts(values = {@PortSpec(value = PortType.MODEL, suggestions = {VectorNearestNeighborTrainLocalOp.class}), @PortSpec(PortType.DATA)})
    @NameCn("向量最近邻预测")
    public static class SubVectorNearestNeighborPredictLocalOp extends ModelMapLocalOp<SubVectorNearestNeighborPredictLocalOp> implements NearestNeighborPredictParams<SubVectorNearestNeighborPredictLocalOp> {

        /** Creates the operator with empty parameters. */
        public SubVectorNearestNeighborPredictLocalOp() {
            this(new Params());
        }

        /**
         * Creates the operator with the given parameters.
         *
         * @param params operator parameters
         */
        public SubVectorNearestNeighborPredictLocalOp(Params params) {
            super(NearestNeighborsMapper::new, params);
        }
    }

    /** Creates the operator with empty parameters. */
    public VectorNearestNeighborPredictLocalOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters
     */
    public VectorNearestNeighborPredictLocalOp(Params params) {
        super(params);
    }

    /**
     * Runs the nearest-neighbor prediction and stores the result as this operator's output table.
     *
     * @param inputs either {@code (model, data)} or just {@code (data)}; in the latter case the model
     *               is read from the model file path parameter
     * @return this operator
     * @throws AkIllegalOperatorParameterException if neither a model input nor a model file path is set
     * @throws AkIllegalDataException if the model dictionary contains an entry that is not a
     *                                {@link FastDistanceMatrixData}
     * @throws AkUnclassifiedErrorException wrapping any other exception raised while predicting
     */
    @Override
    public VectorNearestNeighborPredictLocalOp linkFrom(LocalOperator<?>... inputs) {
        checkMinOpSize(1, inputs);
        final Integer topN = getTopN();
        final Double radius = getRadius();
        AkPreconditions.checkArgument(topN != null || radius != null, "Must give topN or radius!");

        final boolean hasModelInput = inputs.length == 2;
        LocalOperator<?> model = hasModelInput ? inputs[0] : null;
        final LocalOperator<?> data = hasModelInput ? inputs[1] : inputs[0];
        if (model == null) {
            if (getParams().get(HasModelFilePath.MODEL_FILE_PATH) == null) {
                throw new AkIllegalOperatorParameterException("One of model or modelFilePath should be set.");
            }
            model = new AkSourceLocalOp().setFilePath(getModelFilePath());
        }

        try {
            VectorModelDataConverter converter = new VectorModelDataConverter();
            converter.setIdType(model.getColTypes()[model.getColNames().length - 1]);
            VectorModelData modelData = converter.load(model.getOutputTable().getRows());

            final boolean cosine = modelData.fastDistance instanceof CosineDistance;
            boolean matrixPathSupported = cosine || modelData.fastDistance instanceof EuclideanDistance;
            // The batched path below ranks candidates with a top-k selection and therefore needs a
            // concrete topN. A radius-only search (topN == null) used to hit a NullPointerException
            // there, so it is routed to the mapper-based operator like the unsupported distances.
            if (topN == null || !matrixPathSupported) {
                setOutputTable(new SubVectorNearestNeighborPredictLocalOp(getParams()).linkFrom(inputs).getOutputTable());
                return this;
            }
            final int requestedTopN = topN.intValue();

            // Result intentionally discarded; kept so that an empty dictionary still fails here.
            modelData.dictData.get(0).getRows();

            // First pass: validate the dictionary entries and size the stacked matrix.
            int numVectors = 0;
            int lastDim = 0;
            for (int b = 0; b < modelData.dictData.size(); b++) {
                if (!(modelData.dictData.get(b) instanceof FastDistanceMatrixData)) {
                    throw new AkIllegalDataException("Data is not dense vector.");
                }
                DenseMatrix blockVectors = ((FastDistanceMatrixData) modelData.dictData.get(b)).getVectors();
                numVectors += blockVectors.numCols();
                lastDim = blockVectors.numRows();
            }
            final int totalVectors = numVectors;
            final int dim = lastDim;

            // Second pass: stack all dictionary blocks into one (dim x totalVectors) matrix and
            // collect the per-vector labels and ids alongside it.
            final DenseMatrix dictMatrix = new DenseMatrix(dim, totalVectors);
            double[] dictValues = dictMatrix.getData();
            // Used below as the dictionary-side term of the Euclidean distance, i.e. presumably the
            // squared L2 norm of each dictionary vector. Left at zero for cosine distance.
            final double[] dictLabels = new double[totalVectors];
            final Object[] dictIds = new Object[totalVectors];
            int offset = 0;
            for (int b = 0; b < modelData.dictData.size(); b++) {
                FastDistanceMatrixData block = (FastDistanceMatrixData) modelData.dictData.get(b);
                Row[] blockRows = block.getRows();
                int blockSize = blockRows.length;
                System.arraycopy(block.getVectors().getData(), 0, dictValues, offset * dim, blockSize * dim);
                if (!cosine) {
                    System.arraycopy(block.getLabel().getData(), 0, dictLabels, offset, blockSize);
                }
                for (int r = 0; r < blockSize; r++) {
                    dictIds[offset + r] = blockRows[r].getField(0);
                }
                offset += blockSize;
            }

            final MTable queryTable = data.getOutputTable();
            final int numQueries = queryTable.getNumRow();
            final int vectorColIdx = TableUtil.findColIndex(queryTable.getSchema(), getSelectedCol());
            final Row[] results = new Row[numQueries];

            AlinkLocalSession.TaskRunner taskRunner = new AlinkLocalSession.TaskRunner();
            int numThreads = LocalOperator.getDefaultNumThreads();
            if (getParams().contains(HasNumThreads.NUM_THREADS)) {
                numThreads = getParams().get(HasNumThreads.NUM_THREADS);
            }

            for (int t = 0; t < numThreads; t++) {
                // Each task owns the contiguous row range [begin, end) and writes only to those
                // slots of the results array.
                final int begin = (int) AlinkLocalSession.DISTRIBUTOR.startPos(t, numThreads, numQueries);
                final int end = begin + ((int) AlinkLocalSession.DISTRIBUTOR.localRowCnt(t, numThreads, numQueries));
                taskRunner.submit(() -> {
                    for (int batchStart = begin; batchStart < end; batchStart += QUERY_BATCH_SIZE) {
                        int batchSize = Math.min(end - batchStart, QUERY_BATCH_SIZE);

                        // Pack the batch of query vectors into a (dim x batchSize) matrix.
                        DenseMatrix queries = new DenseMatrix(dim, batchSize);
                        double[] queryValues = queries.getData();
                        double[] queryNorms = new double[batchSize];
                        for (int q = 0; q < batchSize; q++) {
                            DenseVector vec = VectorUtil.getDenseVector(queryTable.getRow(batchStart + q).getField(vectorColIdx));
                            queryNorms[q] = cosine ? vec.normL2() : vec.normL2Square();
                            int colStart = q * dim;
                            System.arraycopy(vec.getData(), 0, queryValues, colStart, dim);
                            if (cosine) {
                                // Normalize the copy, not the vector itself: the vector may be the
                                // very object stored in the input table, which must stay untouched.
                                double scale = 1.0d / queryNorms[q];
                                for (int j = 0; j < dim; j++) {
                                    queryValues[colStart + j] *= scale;
                                }
                            }
                        }

                        // scores is (totalVectors x batchSize); column q holds the scores of query q.
                        DenseMatrix scores = new DenseMatrix(totalVectors, batchSize);
                        if (cosine) {
                            // 1 - <dict, query>
                            Arrays.fill(scores.getData(), 1.0d);
                            BLAS.gemm(-1.0d, dictMatrix, true, queries, false, 1.0d, scores);
                        } else {
                            // -2 * <dict, query>; beta is 0 because the freshly allocated target
                            // holds nothing to accumulate onto.
                            BLAS.gemm(-2.0d, dictMatrix, true, queries, false, 0.0d, scores);
                        }
                        double[] scoreValues = scores.getData();
                        if (!cosine) {
                            // Complete the distance: sqrt(|-2<d,q> + label(d) + |q|^2|). The abs
                            // guards against tiny negative values caused by rounding.
                            for (int q = 0; q < batchSize; q++) {
                                int base = q * totalVectors;
                                for (int v = 0; v < totalVectors; v++) {
                                    scoreValues[base + v] = Math.sqrt(Math.abs(scoreValues[base + v] + dictLabels[v] + queryNorms[q]));
                                }
                            }
                        }

                        int[] order = new int[totalVectors];
                        int limit = Math.min(requestedTopN, totalVectors);
                        for (int q = 0; q < batchSize; q++) {
                            int base = q * totalVectors;
                            for (int v = 0; v < totalVectors; v++) {
                                order[v] = base + v;
                            }
                            TopK.heapMinTopK(order, scoreValues, requestedTopN, 0, totalVectors);

                            ArrayList<Object> neighborIds = new ArrayList<>();
                            ArrayList<Double> neighborMetrics = new ArrayList<>();
                            for (int r = 0; r < limit; r++) {
                                double dist = scoreValues[order[r]];
                                if (radius == null || dist <= radius.doubleValue()) {
                                    neighborIds.add(dictIds[order[r] - base]);
                                    neighborMetrics.add(dist);
                                }
                            }
                            results[batchStart + q] = RowUtil.merge(queryTable.getRow(batchStart + q), KObjectUtil.serializeRecomm("ID", neighborIds, ImmutableMap.of("METRIC", neighborMetrics)));
                        }

                        if (AlinkGlobalConfiguration.isPrintProcessInfo() && (batchStart - begin) % PROGRESS_REPORT_INTERVAL == 0) {
                            System.out.printf("one thread predict %d vec\n", batchStart - begin);
                        }
                    }
                });
            }
            taskRunner.join();

            setOutputTable(new MTable(results, TableUtil.schema2SchemaStr(data.getSchema()) + ", " + getOutputCol() + " string"));
            return this;
        } catch (ExceptionWithErrorCode e) {
            // Already carries an error code; pass it through unchanged.
            throw e;
        } catch (Exception e2) {
            throw new AkUnclassifiedErrorException("Error. ", e2);
        }
    }
}
