package com.alibaba.alink.operator.common.clustering.kmodes;

import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.distance.OneZeroDistance;
import com.alibaba.alink.params.clustering.ClusteringPredictParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/kmodes/KModesModelMapper.class */
/**
 * Prediction-time mapper for a K-Modes clustering model.
 *
 * <p>For every input row, the model's feature columns are read from the row. Each value is turned
 * into its {@link String#valueOf(Object)} form, and the row is assigned to the centroid with the
 * smallest {@link OneZeroDistance}. The output always holds the id of that centroid (a
 * {@code Long}). When {@link ClusteringPredictParams#PREDICTION_DETAIL_COL} is set, it also holds
 * the distance to that centroid (a {@code Double}).
 */
public class KModesModelMapper extends ModelMapper {
    private static final long serialVersionUID = 1212257106447281392L;

    /** Metric used to compare a row against each centroid. */
    private final OneZeroDistance distance = new OneZeroDistance();

    /** Whether the nearest-centroid distance is emitted as a second output column. */
    private final boolean isPredDetail;

    /** Model populated by {@link #loadModel(List)}. */
    private KModesModel­Data modelData;

    /** Index, within the data schema, of each model feature column (in model feature order). */
    private int[] colIdx;

    /**
     * Creates the mapper.
     *
     * @param modelSchema schema of the serialized model rows
     * @param dataSchema  schema of the rows to be predicted
     * @param params      prediction parameters; the presence of the detail column enables distance output
     */
    public KModesModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        isPredDetail = params.contains(ClusteringPredictParams.PREDICTION_DETAIL_COL);
    }

    /**
     * Deserializes the model and resolves where each model feature column sits in the data schema.
     *
     * <p>Each lookup is done with {@code TableUtil.findColIndexWithAssert}.
     *
     * @param list the serialized model rows
     */
    @Override
    public void loadModel(List<Row> list) {
        this.modelData = new KModesModel().load(list);
        final String[] featureNames = this.modelData.featureColNames;
        this.colIdx = new int[featureNames.length];
        int pos = 0;
        for (String name : featureNames) {
            this.colIdx[pos++] = TableUtil.findColIndexWithAssert(getDataSchema().getFieldNames(), name);
        }
    }

    /**
     * Assigns a single input row to its nearest centroid.
     *
     * @param selection the selected input columns of the row, indexed by data-schema position
     * @param result    receives the centroid id in slot 0 and, if detail output is on, the distance in slot 1
     * @throws Exception as permitted by the overridden contract
     */
    @Override
    public void map(Mapper.SlicedSelectedSample selection, Mapper.SlicedResult result) throws Exception {
        final int width = colIdx.length;
        final String[] record = new String[width];
        for (int k = 0; k < width; k++) {
            // String.valueOf(Object) maps a null cell to the literal text "null".
            record[k] = String.valueOf(selection.get(colIdx[k]));
        }

        final Tuple2<Long, Double> nearest = getCluster(modelData.centroids, record, distance);
        result.set(0, nearest.f0);
        if (isPredDetail) {
            result.set(1, nearest.f1);
        }
    }

    /**
     * Declares the mapper's input and output columns.
     *
     * <p>All data-schema columns are selected as input. The output is the prediction column
     * ({@code Long}), optionally followed by the prediction-detail column ({@code Double}).
     * Reserved columns are taken from the parameters.
     *
     * @param modelSchema schema of the serialized model rows (unused here)
     * @param dataSchema  schema of the rows to be predicted
     * @param params      prediction parameters
     * @return (selected columns, output column names, output column types, reserved columns)
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema modelSchema, TableSchema dataSchema, Params params) {
        final boolean withDetail = params.contains(ClusteringPredictParams.PREDICTION_DETAIL_COL);
        final String[] selectedCols = dataSchema.getFieldNames();
        final String predictionCol = params.get(ClusteringPredictParams.PREDICTION_COL);

        final String[] outputCols;
        final TypeInformation<?>[] outputTypes;
        if (withDetail) {
            outputCols = new String[] {predictionCol, params.get(ClusteringPredictParams.PREDICTION_DETAIL_COL)};
            outputTypes = new TypeInformation<?>[] {Types.LONG, Types.DOUBLE};
        } else {
            outputCols = new String[] {predictionCol};
            outputTypes = new TypeInformation<?>[] {Types.LONG};
        }

        return Tuple4.of(selectedCols, outputCols, outputTypes, params.get(ClusteringPredictParams.RESERVED_COLS));
    }

    /**
     * Finds the centroid nearest to {@code sample}.
     *
     * <p>Centroids are scanned in iteration order, and only a strictly smaller distance replaces the
     * current best, so on ties the earliest centroid wins. If there are no centroids, the result is
     * {@code (-1, +Infinity)}.
     *
     * @param centroids candidate centroids; {@code f0} is the id that is returned, and {@code f2}
     *                  holds the values compared against the sample
     * @param sample    the row's feature values, in model feature order
     * @param metric    distance used for the comparison
     * @return the id of the nearest centroid paired with its distance
     */
    private static Tuple2<Long, Double> getCluster(
            Iterable<Tuple3<Long, Double, String[]>> centroids, String[] sample, OneZeroDistance metric) {
        long bestId = -1L;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (Tuple3<Long, Double, String[]> centroid : centroids) {
            final double candidate = metric.calc(sample, centroid.f2);
            if (candidate < bestDistance) {
                bestDistance = candidate;
                bestId = centroid.f0;
            }
        }
        return Tuple2.of(bestId, bestDistance);
    }
}
