package com.alibaba.alink.operator.common.regression;

import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.dataproc.ScalerUtil;
import com.alibaba.alink.params.regression.IsotonicRegPredictParams;
import com.alibaba.alink.params.regression.IsotonicRegTrainParams;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/regression/IsotonicRegressionModelMapper.class */
/**
 * Applies a trained isotonic regression model to input rows, emitting one {@code Double}
 * prediction column.
 *
 * <p>The model holds {@code boundaries} and their fitted {@code values}. Prediction works as follows:
 * <ul>
 *   <li>A feature equal to a boundary takes that boundary's value.</li>
 *   <li>A feature below the first boundary takes the first value.</li>
 *   <li>A feature above the last boundary takes the last value.</li>
 *   <li>A feature strictly between two boundaries is passed to {@link ScalerUtil#minMaxScaler}
 *       together with the neighbouring boundaries and values. Presumably this linearly
 *       interpolates between the two values; confirm against ScalerUtil.</li>
 * </ul>
 * A {@code null} input cell yields a {@code null} prediction.
 *
 * <p>The feature is read from the numeric column named by {@code featureCol} in the model
 * metadata. When {@code vectorCol} is set, it is read instead from position
 * {@code featureIndex} of that vector column.
 */
public class IsotonicRegressionModelMapper extends ModelMapper {
    private static final long serialVersionUID = 4565470971830328037L;

    /** Index, within the data schema, of the column holding the feature (scalar or vector). */
    private int colIdx;
    /** Model deserialized in {@link #loadModel(List)}. */
    private IsotonicRegressionModelData modelData;
    /** Vector column name from the model metadata; {@code null} when the feature is a plain numeric column. */
    private String vectorColName;
    /** Position of the feature inside the vector column; only used when {@link #vectorColName} is non-null. */
    private int featureIndex;

    /**
     * Creates the mapper.
     *
     * @param tableSchema  presumably the model table schema (forwarded to {@link ModelMapper}; confirm)
     * @param tableSchema2 schema of the input data; its field names are selected as input columns
     * @param params       prediction parameters, including the prediction output column name
     */
    public IsotonicRegressionModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
    }

    /**
     * Deserializes the model rows and resolves which input column carries the feature.
     *
     * <p>If the model metadata names a vector column, that column is looked up in the data schema.
     * Otherwise the scalar feature column is looked up. Lookup goes through
     * {@link TableUtil#findColIndexWithAssert}, which presumably fails when the column is absent.
     *
     * @param list serialized model rows, decoded by {@link IsotonicRegressionConverter}
     */
    @Override
    public void loadModel(List<Row> list) {
        this.modelData = new IsotonicRegressionConverter().load(list);
        Params meta = this.modelData.meta;
        String featureColName = meta.get(IsotonicRegTrainParams.FEATURE_COL);
        this.vectorColName = meta.get(IsotonicRegTrainParams.VECTOR_COL);
        this.featureIndex = meta.get(IsotonicRegTrainParams.FEATURE_INDEX);

        String[] dataFieldNames = getDataSchema().getFieldNames();
        String sourceCol = null == this.vectorColName ? featureColName : this.vectorColName;
        this.colIdx = TableUtil.findColIndexWithAssert(dataFieldNames, sourceCol);
    }

    /**
     * Declares the mapper's I/O.
     *
     * <p>All data columns are selected as input. A single {@code DOUBLE} column, named by
     * {@code predictionCol}, is produced as output.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        // NOTE(review): reads this.params (as the original did) rather than the params argument.
        String[] outputCols = new String[] {this.params.get(IsotonicRegPredictParams.PREDICTION_COL)};
        TypeInformation<?>[] outputTypes = new TypeInformation<?>[] {Types.DOUBLE};
        // Plain null lets T3 infer as String[] from the return type. The former "(Object) null"
        // forced T3 = Object and did not compile. Fourth element presumably = reserved columns; confirm.
        return Tuple4.of(tableSchema2.getFieldNames(), outputCols, outputTypes, null);
    }

    /**
     * Predicts the isotonic regression output for one row.
     *
     * @param slicedSelectedSample the selected input columns of the row
     * @param slicedResult         receives the prediction in slot 0, or {@code null} for a null input
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult)
            throws Exception {
        Object cell = slicedSelectedSample.get(this.colIdx);
        if (null == cell) {
            slicedResult.set(0, null);
            return;
        }
        double feature = null == this.vectorColName
            ? ((Number) cell).doubleValue()
            : VectorUtil.getVector(cell).get(this.featureIndex);
        slicedResult.set(0, Double.valueOf(evaluateIsotonic(feature)));
    }

    /** Evaluates the fitted piecewise function at {@code feature} (see class doc for the cases). */
    private double evaluateIsotonic(double feature) {
        // Arrays.binarySearch requires the boundaries to be sorted ascending; the model is assumed to guarantee it.
        int found = Arrays.binarySearch(this.modelData.boundaries, Double.valueOf(feature));
        if (found >= 0) {
            // Exact hit on a boundary.
            return this.modelData.values[found].doubleValue();
        }
        // binarySearch encodes a miss as (-(insertionPoint) - 1).
        int insertionPoint = -found - 1;
        if (insertionPoint == 0) {
            // Below the first boundary: clamp to the first value.
            return this.modelData.values[0].doubleValue();
        }
        if (insertionPoint == this.modelData.boundaries.length) {
            // Above the last boundary: clamp to the last value.
            return this.modelData.values[this.modelData.values.length - 1].doubleValue();
        }
        // Strictly between boundaries[insertionPoint - 1] and boundaries[insertionPoint].
        return ScalerUtil.minMaxScaler(feature,
            this.modelData.boundaries[insertionPoint - 1].doubleValue(),
            this.modelData.boundaries[insertionPoint].doubleValue(),
            this.modelData.values[insertionPoint].doubleValue(),
            this.modelData.values[insertionPoint - 1].doubleValue());
    }
}
