package com.alibaba.alink.operator.common.dataproc.vector;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.SISOModelMapper;
import com.alibaba.alink.common.model.RichModelDataConverter;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.common.dataproc.ScalerUtil;
import com.alibaba.alink.params.dataproc.vector.VectorSrtPredictorParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/vector/VectorMaxAbsScalerModelMapper.class */
/**
 * Single-input/single-output model mapper that rescales a vector column using the
 * per-dimension factors stored in a max-abs scaler model.
 *
 * <p>The selected column is not taken from the caller's params; it is read from the
 * model schema via {@link RichModelDataConverter#extractSelectedColNames(TableSchema)}
 * and written into the params before they are handed to the superclass.
 *
 * <p>NOTE: scaling is performed in place on the vector obtained from
 * {@link VectorUtil#getVector(Object)}; the same instance is returned.
 */
public class VectorMaxAbsScalerModelMapper extends SISOModelMapper {
    private static final long serialVersionUID = -8019343040061141104L;

    /** Per-dimension scaling factors; populated by {@link #loadModel(List)}. */
    private double[] maxAbs;

    /**
     * Creates the mapper and forces the selected column to the first column name
     * recorded in the model schema.
     *
     * @param modelSchema schema of the model table; the selected column name is extracted from it
     * @param dataSchema  schema forwarded unchanged to the superclass as its second argument
     * @param params      mapper parameters; {@code SELECTED_COL} is overwritten on this instance
     */
    public VectorMaxAbsScalerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        // SELECTED_COL is a ParamInfo<String> and the extracted column name is a String,
        // so no casts are needed (or legal) here.
        super(modelSchema, dataSchema,
            params.set(VectorSrtPredictorParams.SELECTED_COL,
                RichModelDataConverter.extractSelectedColNames(modelSchema)[0]));
    }

    /**
     * @return the result column type, always {@link AlinkTypes#VECTOR}
     */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected TypeInformation initPredResultColType() {
        return AlinkTypes.VECTOR;
    }

    /**
     * Scales one input value.
     *
     * @param obj the raw column value; converted with {@link VectorUtil#getVector(Object)}
     * @return the scaled vector, or {@code null} when the converted vector is {@code null}
     */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected Object predictResult(Object obj) {
        Vector input = VectorUtil.getVector(obj);
        if (input == null) {
            return null;
        }
        if (input instanceof DenseVector) {
            return predict((DenseVector) input);
        }
        // Anything that is not dense is treated as sparse, as in the original dispatch.
        return predict((SparseVector) input);
    }

    /**
     * Loads the scaling factors from the model rows.
     *
     * @param list model rows, decoded by {@link VectorMaxAbsScalerModelDataConverter}
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        this.maxAbs = new VectorMaxAbsScalerModelDataConverter().load(list);
    }

    /**
     * Scales every component of a dense vector in place.
     *
     * @param denseVector vector to scale; its backing array is overwritten
     * @return the same {@code denseVector} instance
     */
    private DenseVector predict(DenseVector denseVector) {
        double[] data = denseVector.getData();
        int size = denseVector.size();
        for (int i = 0; i < size; i++) {
            // presumably value / maxAbs — exact formula lives in ScalerUtil.maxAbsScaler
            data[i] = ScalerUtil.maxAbsScaler(this.maxAbs[i], data[i]);
        }
        return denseVector;
    }

    /**
     * Scales the stored (non-default) entries of a sparse vector in place, looking up
     * each entry's factor by its dimension index.
     *
     * @param sparseVector vector to scale; its values array is overwritten
     * @return the same {@code sparseVector} instance
     */
    private SparseVector predict(SparseVector sparseVector) {
        // Fetch the backing arrays once instead of on every iteration.
        int[] indices = sparseVector.getIndices();
        double[] values = sparseVector.getValues();
        int count = sparseVector.numberOfValues();
        for (int k = 0; k < count; k++) {
            values[k] = ScalerUtil.maxAbsScaler(this.maxAbs[indices[k]], values[k]);
        }
        return sparseVector;
    }
}
