package com.alibaba.alink.operator.common.clustering.kmeans;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.params.clustering.KMeansPredictParams;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmeans/KMeansModelMapper.class */
/**
 * Model mapper that assigns each input row to a KMeans centroid.
 *
 * <p>Output columns, in order:
 * <ol>
 *   <li>the prediction column ({@code Long}): the index returned by
 *       {@link KMeansUtil#getMinPointIndex} over the per-centroid distances;</li>
 *   <li>optionally, the prediction detail column ({@code String}): the {@code toString()} of a
 *       {@link DenseVector} of probabilities. It is present only when
 *       {@code PREDICTION_DETAIL_COL} is set;</li>
 *   <li>optionally, the prediction distance column ({@code Double}): the distance to the selected
 *       centroid. It is present only when {@code PREDICTION_DISTANCE_COL} is set.</li>
 * </ol>
 */
public class KMeansModelMapper extends ModelMapper {
    private static final long serialVersionUID = -7232694013661020935L;

    /** Model data, populated by {@link #loadModel(List)}. */
    private KMeansPredictModelData modelData;
    /** Indices of the selected input columns, as resolved by {@code KMeansUtil.getKmeansPredictColIdxs}. */
    private int[] colIdx;
    /** Distance implementation obtained from the model's distance type. */
    private FastDistance distance;
    /** Whether the prediction detail column is emitted. */
    private final boolean isPredDetail;
    /** Whether the prediction distance column is emitted. */
    private final boolean isPredDistance;

    /**
     * Creates the mapper.
     *
     * @param tableSchema  schema forwarded to {@link ModelMapper} as its first schema argument
     *                     (NOTE(review): presumably the model schema — confirm against ModelMapper)
     * @param tableSchema2 schema whose field names are used as the selected input columns
     * @param params       prediction parameters ({@link KMeansPredictParams})
     */
    public KMeansModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.isPredDetail = params.contains(KMeansPredictParams.PREDICTION_DETAIL_COL);
        this.isPredDistance = params.contains(KMeansPredictParams.PREDICTION_DISTANCE_COL);
    }

    /**
     * Predicts the cluster of one input row and writes the enabled output columns.
     *
     * <p>If the extracted vector is {@code null}, every enabled output column is set to
     * {@code null}.
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        Vector vector = extractVector(slicedSelectedSample);
        // The distance column shifts right by one when the detail column is also emitted.
        final int distanceSlot = this.isPredDetail ? 2 : 1;

        if (vector == null) {
            slicedResult.set(0, null);
            if (this.isPredDetail) {
                slicedResult.set(1, null);
            }
            if (this.isPredDistance) {
                slicedResult.set(distanceSlot, null);
            }
            return;
        }

        final int k = this.modelData.params.k;
        // A fresh k x 1 buffer per call keeps the mapper free of shared mutable state.
        double[] clusterDistances = KMeansUtil.getClusterDistances(
            this.distance.prepareVectorData(Tuple2.of(vector, (Object) null)),
            this.modelData.centroids,
            this.distance,
            new DenseMatrix(k, 1));
        int minPointIndex = KMeansUtil.getMinPointIndex(clusterDistances, k);

        slicedResult.set(0, Long.valueOf(minPointIndex));
        if (this.isPredDetail) {
            slicedResult.set(1, buildDetail(clusterDistances));
        }
        if (this.isPredDistance) {
            slicedResult.set(distanceSlot, Double.valueOf(clusterDistances[minPointIndex]));
        }
    }

    /**
     * Builds the input vector for one row.
     *
     * <p>When more than one column is selected, the first two are read as numbers into a
     * two-element dense vector. Otherwise the single selected column is parsed with
     * {@link VectorUtil#getVector}.
     *
     * @return the vector, or whatever {@code VectorUtil.getVector} returns (possibly {@code null})
     */
    private Vector extractVector(Mapper.SlicedSelectedSample sample) {
        if (this.colIdx.length <= 1) {
            return VectorUtil.getVector(sample.get(this.colIdx[0]));
        }
        Vector pair = new DenseVector(2);
        pair.set(0, ((Number) sample.get(this.colIdx[0])).doubleValue());
        pair.set(1, ((Number) sample.get(this.colIdx[1])).doubleValue());
        return pair;
    }

    /**
     * Converts per-centroid distances to probabilities and renders them as a dense-vector string.
     *
     * <p>Probability {@code i} is stored at position {@code modelData.getClusterId(i)} of the output
     * vector.
     */
    private String buildDetail(double[] clusterDistances) {
        double[] probs = KMeansUtil.getProbArrayFromDistanceArray(clusterDistances);
        DenseVector detail = new DenseVector(probs.length);
        for (int i = 0; i < this.modelData.params.k; i++) {
            detail.set((int) this.modelData.getClusterId(i), probs[i]);
        }
        return detail.toString();
    }

    /**
     * Declares the selected input columns, the output columns and their types, and the reserved
     * columns.
     *
     * <p>The optional-column flags are read from {@code params} rather than from the instance fields.
     * NOTE(review): presumably because this may run from the super constructor, before the final
     * fields are assigned — confirm against ModelMapper.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        String[] reservedCols = (String[]) params.get(KMeansPredictParams.RESERVED_COLS);
        String predictionCol = (String) params.get(KMeansPredictParams.PREDICTION_COL);

        List<String> outputCols = new ArrayList<>();
        List<TypeInformation<?>> outputTypes = new ArrayList<>();

        outputCols.add(predictionCol);
        outputTypes.add(Types.LONG);
        if (params.contains(KMeansPredictParams.PREDICTION_DETAIL_COL)) {
            outputCols.add((String) params.get(KMeansPredictParams.PREDICTION_DETAIL_COL));
            outputTypes.add(Types.STRING);
        }
        if (params.contains(KMeansPredictParams.PREDICTION_DISTANCE_COL)) {
            outputCols.add((String) params.get(KMeansPredictParams.PREDICTION_DISTANCE_COL));
            outputTypes.add(Types.DOUBLE);
        }

        return Tuple4.of(
            tableSchema2.getFieldNames(),
            outputCols.toArray(new String[0]),
            outputTypes.toArray(new TypeInformation<?>[0]),
            reservedCols);
    }

    /**
     * Loads the model rows.
     *
     * <p>Also resolves the distance implementation from the model's distance type and the indices
     * of the input columns used for prediction.
     */
    @Override
    public void loadModel(List<Row> list) {
        this.modelData = new KMeansModelDataConverter().load(list);
        this.distance = this.modelData.params.distanceType.getFastDistance();
        this.colIdx = KMeansUtil.getKmeansPredictColIdxs(this.modelData.params, getDataSchema().getFieldNames());
    }
}
