package com.alibaba.alink.operator.common.classification;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.similarity.NearestNeighborsMapper;
import com.alibaba.alink.params.classification.KnnPredictParams;
import com.alibaba.alink.params.classification.KnnTrainParams;
import com.alibaba.alink.params.similarity.NearestNeighborPredictParams;
import java.lang.reflect.Type;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/KnnMapper.class */
/**
 * Model mapper that turns the output of a {@link NearestNeighborsMapper} into a single predicted
 * label by letting the returned neighbours vote with equal weight.
 *
 * <p>The nearest-neighbour search itself is delegated to the wrapped {@link NearestNeighborsMapper};
 * this class only assembles the query vector from the input row and aggregates the neighbour
 * labels.
 */
public class KnnMapper extends ModelMapper {
    private static final long serialVersionUID = -6357517568280870848L;

    /** Delegate that loads the model and performs the neighbour lookup. */
    private final NearestNeighborsMapper mapper;

    /** Whether a prediction-detail column was requested and must be filled in {@link #map}. */
    private final boolean isPredDetail;

    /** Indices of the feature columns in the data schema; stays {@code null} when a vector column is used. */
    private int[] selectedIndices;

    /** Index of the vector column in the data schema; only meaningful when {@link #selectedIndices} is {@code null}. */
    private int selectIndex;

    /** Type class of the ids reported by the delegate, handed to {@code extractKObject} when decoding its result. */
    private final Type idType;

    /**
     * Creates the mapper and its nearest-neighbour delegate.
     *
     * <p>Note that {@code params} is modified in place: the selected column is set to the first
     * field of {@code tableSchema2} and {@code TOP_N} is set to the configured {@code K} before the
     * delegate is constructed from the same {@code Params} instance.
     *
     * @param tableSchema  schema passed to the super class and the delegate as the first (model) schema
     * @param tableSchema2 schema passed to the super class and the delegate as the second (data) schema
     * @param params       prediction parameters; mutated as described above
     */
    public KnnMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        // The decompiled source wrapped both arguments in casts between unrelated types
        // (e.g. String -> ParamInfo); plain generic calls are what the API expects.
        params.set(NearestNeighborPredictParams.SELECTED_COL, tableSchema2.getFieldNames()[0]);
        params.set(NearestNeighborPredictParams.TOP_N, params.get(KnnPredictParams.K));
        this.mapper = new NearestNeighborsMapper(tableSchema, tableSchema2, params);
        this.isPredDetail = params.contains(KnnPredictParams.PREDICTION_DETAIL_COL);
        this.idType = this.mapper.getIdType().getTypeClass();
    }

    /**
     * Aggregates neighbour labels into a winning label and a JSON description of the vote shares.
     *
     * <p>Every element of {@code tuple2.f0} contributes {@code 1 / size} to its label. The label
     * with the strictly largest share wins; on ties the entry met first while iterating the map is
     * kept. For an empty list the winning label is {@code null} and the map is empty.
     *
     * @param tuple2 decoded neighbour result; only {@code f0} is read here
     * @return the winning label and the JSON-serialised label-to-share map
     */
    private Tuple2<Object, String> getKnn(Tuple2<List<Object>, List<Object>> tuple2) {
        List<Object> labels = tuple2.f0;
        double share = 1.0d / labels.size();

        // Initial capacity is kept at 0 as in the original: HashMap iteration order depends on the
        // table size, and that order decides both tie-breaking and the JSON entry order.
        Map<Object, Double> labelShares = new HashMap<>(0);
        for (Object label : labels) {
            labelShares.merge(label, share, Double::sum);
        }

        double bestShare = 0.0d;
        Object bestLabel = null;
        for (Map.Entry<Object, Double> entry : labelShares.entrySet()) {
            double candidate = entry.getValue();
            if (candidate > bestShare) {
                bestShare = candidate;
                bestLabel = entry.getKey();
            }
        }
        return Tuple2.of(bestLabel, JsonConverter.toJson(labelShares));
    }

    /**
     * Loads the model into the delegate and resolves which input columns form the query vector.
     *
     * <p>If the model meta contains feature columns, their indices in the data schema are stored;
     * otherwise the index of the configured vector column is stored.
     *
     * @param list model rows, forwarded unchanged to the delegate
     */
    @Override
    public void loadModel(List<Row> list) {
        this.mapper.loadModel(list);
        Params meta = this.mapper.getMeta();
        String[] featureCols = meta.get(KnnTrainParams.FEATURE_COLS);
        if (featureCols != null) {
            this.selectedIndices = TableUtil.findColIndicesWithAssertAndHint(getDataSchema(), featureCols);
        } else {
            this.selectIndex = TableUtil.findColIndexWithAssertAndHint(
                getDataSchema(), meta.get(KnnTrainParams.VECTOR_COL));
        }
    }

    /**
     * Predicts the label for one input sample.
     *
     * <p>The query vector is built either from the individual feature columns (each must be a
     * non-null {@link Number}) or parsed from the vector column. Slot 0 of the result receives the
     * winning label; slot 1 receives the JSON vote shares when a detail column was requested.
     *
     * @param slicedSelectedSample the selected input values
     * @param slicedResult         the output slots to fill
     * @throws Exception declared by the overridden method; failures of the delegate lookup propagate unchanged
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        Vector vector;
        if (this.selectedIndices != null) {
            DenseVector dense = new DenseVector(this.selectedIndices.length);
            for (int i = 0; i < this.selectedIndices.length; i++) {
                Object value = slicedSelectedSample.get(this.selectedIndices[i]);
                AkPreconditions.checkNotNull(value, "There is NULL in featureCols!");
                dense.set(i, ((Number) value).doubleValue());
            }
            vector = dense;
        } else {
            vector = VectorUtil.getVector(slicedSelectedSample.get(this.selectIndex));
        }

        String neighborsJson = (String) this.mapper.predictResult(vector);
        Tuple2<Object, String> knn = getKnn(NearestNeighborsMapper.extractKObject(neighborsJson, this.idType));
        slicedResult.set(0, knn.f0);
        if (this.isPredDetail) {
            slicedResult.set(1, knn.f1);
        }
    }

    /**
     * Describes the input and output columns of this mapper.
     *
     * <p>All fields of {@code tableSchema2} are selected as input. The prediction column takes the
     * type of the last field of {@code tableSchema}; when a detail column is configured, an
     * additional string column is emitted.
     *
     * @param tableSchema  schema whose last field type is used as the prediction type
     * @param tableSchema2 schema whose field names are used as the selected columns
     * @param params       prediction parameters
     * @return selected columns, output column names, output column types and reserved columns
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        TypeInformation<?> labelType = tableSchema.getFieldTypes()[tableSchema.getFieldNames().length - 1];
        String predictionCol = params.get(KnnPredictParams.PREDICTION_COL);

        String[] outputCols;
        TypeInformation<?>[] outputTypes;
        if (params.contains(KnnPredictParams.PREDICTION_DETAIL_COL)) {
            outputCols = new String[]{predictionCol, params.get(KnnPredictParams.PREDICTION_DETAIL_COL)};
            outputTypes = new TypeInformation<?>[]{labelType, Types.STRING};
        } else {
            outputCols = new String[]{predictionCol};
            outputTypes = new TypeInformation<?>[]{labelType};
        }
        return Tuple4.of(
            tableSchema2.getFieldNames(),
            outputCols,
            outputTypes,
            params.get(KnnPredictParams.RESERVED_COLS));
    }
}
