package com.alibaba.alink.operator.common.dataproc.vector;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.SISOModelMapper;
import com.alibaba.alink.common.model.RichModelDataConverter;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.common.dataproc.ScalerUtil;
import com.alibaba.alink.params.dataproc.vector.VectorSrtPredictorParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/vector/VectorMinMaxScalerModelMapper.class */
/**
 * Applies a fitted vector min-max scaling model to a single vector column.
 *
 * <p>The model, read by {@link VectorMinMaxScalerModelDataConverter}, holds a target range
 * {@code [min, max]} and per-dimension bounds {@code eMins} / {@code eMaxs}. The element at index
 * {@code i} is rescaled with {@link ScalerUtil#minMaxScaler} using {@code eMins[i]} and
 * {@code eMaxs[i]}.
 *
 * <p>Dense inputs produce a new dense vector of the same size; the input vector is never modified.
 * Sparse inputs produce a dense vector whose length equals the model dimension
 * ({@code eMaxs.length}). Every dimension, including implicit zeros, is passed through the scaler.
 */
public class VectorMinMaxScalerModelMapper extends SISOModelMapper {
    private static final long serialVersionUID = 8736845046919733011L;

    /** Per-dimension lower bounds from the model (presumably the minima observed when fitting). */
    private double[] eMins;
    /** Per-dimension upper bounds from the model (presumably the maxima observed when fitting). */
    private double[] eMaxs;
    /** Lower end of the target output range. */
    private double min;
    /** Upper end of the target output range. */
    private double max;

    /**
     * Creates the mapper and pins the selected column to the first column name that
     * {@link RichModelDataConverter#extractSelectedColNames} reports for the model schema.
     *
     * <p>Note: this writes {@code SELECTED_COL} into the supplied {@code params} object itself.
     *
     * @param modelSchema schema of the model table
     * @param dataSchema  schema of the data table to transform
     * @param params      mapper parameters
     */
    public VectorMinMaxScalerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema,
            params.set(VectorSrtPredictorParams.SELECTED_COL,
                RichModelDataConverter.extractSelectedColNames(modelSchema)[0]));
    }

    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected TypeInformation<?> initPredResultColType() {
        return AlinkTypes.VECTOR;
    }

    /**
     * Scales one vector value.
     *
     * @param obj the raw column value; converted with {@link VectorUtil#getVector}
     * @return the scaled vector, or {@code null} when {@code VectorUtil.getVector} yields null
     */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    public Object predictResult(Object obj) {
        Vector vector = VectorUtil.getVector(obj);
        if (vector == null) {
            return null;
        }
        return vector instanceof DenseVector
            ? predict((DenseVector) vector)
            : predict((SparseVector) vector);
    }

    /**
     * Reads the target range and per-dimension bounds from the model rows.
     *
     * @param list model rows, decoded by {@link VectorMinMaxScalerModelDataConverter}
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        Tuple4<Double, Double, double[], double[]> model =
            new VectorMinMaxScalerModelDataConverter().load(list);
        this.min = model.f0;
        this.max = model.f1;
        this.eMins = model.f2;
        this.eMaxs = model.f3;
    }

    /**
     * Scales a dense vector into a freshly allocated vector. The input is left untouched.
     * Mutating it would leak into any reserved column or other holder of the same object.
     *
     * @throws IllegalArgumentException if the vector is longer than the model dimension
     */
    private DenseVector predict(DenseVector denseVector) {
        int n = denseVector.size();
        if (n > eMins.length) {
            throw new IllegalArgumentException(
                "Vector size " + n + " exceeds the model dimension " + eMins.length + ".");
        }
        double[] in = denseVector.getData();
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = ScalerUtil.minMaxScaler(in[i], eMins[i], eMaxs[i], max, min);
        }
        return new DenseVector(out);
    }

    /**
     * Scales a sparse vector into a dense vector of the model dimension.
     *
     * <p>Each dimension is first filled with the scaled value of an implicit zero. Then only the
     * stored entries are overwritten, which avoids a per-index lookup into the sparse vector.
     * Stored indices outside {@code [0, eMaxs.length)} are ignored, matching a per-index walk
     * over the model dimension.
     */
    private DenseVector predict(SparseVector sparseVector) {
        int n = eMaxs.length;
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = ScalerUtil.minMaxScaler(0.0, eMins[i], eMaxs[i], max, min);
        }
        int[] indices = sparseVector.getIndices();
        double[] values = sparseVector.getValues();
        for (int k = 0; k < indices.length; k++) {
            int idx = indices[k];
            if (idx >= 0 && idx < n) {
                out[idx] = ScalerUtil.minMaxScaler(values[k], eMins[idx], eMaxs[idx], max, min);
            }
        }
        return new DenseVector(out);
    }
}
