package com.alibaba.alink.operator.common.clustering.dbscan;

import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.distance.FastDistanceData;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.params.clustering.ClusteringPredictParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelMapper.class */
/**
 * Model mapper that predicts a DBSCAN cluster id for each input vector.
 *
 * <p>A sample receives the cluster id stored with its nearest core object, provided that distance
 * does not exceed the model's {@code epsilon}. Otherwise it is labelled as noise with
 * {@link #NOISE_CLUSTER_ID}. The result is emitted as a single {@code LONG} column named by
 * {@link ClusteringPredictParams#PREDICTION_COL}.
 */
public class DbscanModelMapper extends ModelMapper {
    private static final long serialVersionUID = -3771648601253028057L;

    /**
     * Cluster id emitted for samples that are farther than {@code epsilon} from every core object.
     * Numerically equal to {@code Integer.MIN_VALUE}, the value this mapper has always produced.
     */
    private static final long NOISE_CLUSTER_ID = Integer.MIN_VALUE;

    /** Model loaded by {@link #loadModel(List)}; {@code null} until then. */
    private DbscanModelPredictData modelData;

    /** Index of the model's vector column within the input data schema; set by {@link #loadModel(List)}. */
    private int colIdx;

    /**
     * Creates the mapper. All arguments are passed through to {@link ModelMapper}.
     *
     * @param modelSchema schema of the model table
     * @param dataSchema  schema of the data to predict on
     * @param params      prediction parameters (prediction column, reserved columns, ...)
     */
    public DbscanModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
    }

    /**
     * Deserializes the model rows and resolves the position of the vector column in the data schema.
     *
     * @param modelRows rows of the DBSCAN model table
     */
    @Override
    public void loadModel(List<Row> modelRows) {
        this.modelData = new DbscanModelDataConverter().load(modelRows);
        // Fails fast if the input data does not contain the vector column the model was trained on.
        this.colIdx = TableUtil.findColIndexWithAssert(getDataSchema().getFieldNames(), modelData.vectorColName);
    }

    /**
     * Writes the predicted cluster id (as a {@link Long}) into output slot 0.
     */
    @Override
    public void map(Mapper.SlicedSelectedSample selection, Mapper.SlicedResult result) throws Exception {
        Vector vector = VectorUtil.getVector(selection.get(colIdx));
        result.set(0, Long.valueOf(findCluster(vector)));
    }

    /**
     * Declares the mapper's I/O: all data columns as input, one {@code LONG} prediction column as
     * output, and the configured reserved columns.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema modelSchema, TableSchema dataSchema, Params params) {
        return Tuple4.of(
            dataSchema.getFieldNames(),
            new String[] {(String) params.get(ClusteringPredictParams.PREDICTION_COL)},
            new TypeInformation[] {Types.LONG},
            params.get(ClusteringPredictParams.RESERVED_COLS));
    }

    /**
     * Returns the cluster id of the core object nearest to {@code vector}, or
     * {@link #NOISE_CLUSTER_ID} when that nearest distance exceeds {@code epsilon}.
     *
     * @param vector the sample to classify
     * @return the predicted cluster id
     */
    private long findCluster(Vector vector) {
        FastDistanceVectorData sample =
            modelData.baseDistance.prepareVectorData(Row.of(vector), 0, new int[0]);

        long nearestClusterId = -1L;
        double nearestDistance = Double.POSITIVE_INFINITY;
        for (FastDistanceVectorData core : modelData.coreObjects) {
            // The casts pin the FastDistanceData overload of calc, whose result is read at (0, 0).
            // Keep them so overload resolution (and the return type) does not change.
            double distance = modelData.baseDistance
                .calc((FastDistanceData) core, (FastDistanceData) sample)
                .get(0, 0);
            if (distance < nearestDistance) {
                // The core object's cluster id is stored in field 0 of its first row.
                nearestClusterId = ((Long) core.getRows()[0].getField(0)).longValue();
                nearestDistance = distance;
            }
        }
        // With no core objects the distance stays +Infinity, so the sample is noise for any finite epsilon.
        return nearestDistance > modelData.epsilon ? NOISE_CLUSTER_ID : nearestClusterId;
    }
}
