package com.alibaba.alink.operator.common.dataproc.vector;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.SISOModelMapper;
import com.alibaba.alink.common.model.RichModelDataConverter;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.dataproc.vector.VectorSrtPredictorParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/vector/VectorStandardScalerModelMapper.class */
/**
 * Model mapper that applies a fitted vector standard scaler to a single vector column.
 *
 * <p>For every feature {@code i} the output is {@code (x_i - means[i]) / stdDeviations[i]} when the
 * model was trained with centering ({@code withMean}), and {@code x_i / stdDeviations[i]} otherwise.
 * Features whose standard deviation equals {@code 0.0} are mapped to {@code 0.0}.
 *
 * <p>Dense input always yields a dense output. Sparse input yields a sparse output when no centering
 * is applied, and a dense output of length {@code means.length} when centering is applied (because
 * centering turns implicit zeros into {@code -mean/std}). Input vectors are never modified.
 */
public class VectorStandardScalerModelMapper extends SISOModelMapper {
    private static final long serialVersionUID = -4913148799471049498L;

    /** Standard deviation value marking a constant feature; such features are mapped to 0. */
    private static final double ZERO_STD = 0.0;

    /** Per-feature means, taken from field {@code f2} of the loaded model tuple. */
    private double[] means;
    /** Per-feature standard deviations, taken from field {@code f3} of the loaded model tuple. */
    private double[] stdDeviations;
    /** Whether to center features, taken from field {@code f0} of the loaded model tuple. */
    private boolean withMean;

    /**
     * Creates the mapper.
     *
     * <p>Note: {@code params} is modified in place. Its selected column is set to the first column
     * name recorded in {@code modelSchema}.
     *
     * @param modelSchema schema of the model table; the selected column name is extracted from it
     * @param dataSchema  schema of the input data table
     * @param params      mapper parameters
     */
    public VectorStandardScalerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema,
            params.set(VectorSrtPredictorParams.SELECTED_COL,
                RichModelDataConverter.extractSelectedColNames(modelSchema)[0]));
    }

    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected TypeInformation<?> initPredResultColType() {
        return AlinkTypes.VECTOR;
    }

    /**
     * Standardizes one input value.
     *
     * @param obj the raw column value (anything {@link VectorUtil#getVector} accepts), may be null
     * @return the standardized vector, or {@code null} when {@code obj} is null
     */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected Object predictResult(Object obj) {
        if (null == obj) {
            return null;
        }
        Vector vector = VectorUtil.getVector(obj);
        if (vector instanceof DenseVector) {
            return predict((DenseVector) vector);
        }
        return predict((SparseVector) vector);
    }

    /**
     * Loads the scaler parameters.
     *
     * <p>The model tuple's {@code f1} field is not read by this mapper. NOTE(review): presumably the
     * trainer already stores unit standard deviations when scaling is disabled; confirm.
     *
     * @param list the model rows
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        Tuple4<Boolean, Boolean, double[], double[]> model = new VectorStandardScalerModelDataConverter().load(list);
        this.withMean = model.f0;
        this.means = model.f2;
        this.stdDeviations = model.f3;
    }

    /** Standardizes a dense vector into a fresh copy, leaving the input untouched. */
    private DenseVector predict(DenseVector denseVector) {
        DenseVector result = denseVector.clone();
        double[] data = result.getData();
        for (int i = 0; i < data.length; i++) {
            data[i] = standardize(data[i], i, withMean);
        }
        return result;
    }

    /** Standardizes a sparse vector; centering requires a dense result, scaling alone does not. */
    private Vector predict(SparseVector sparseVector) {
        return withMean ? centerAndScaleSparse(sparseVector) : scaleSparse(sparseVector);
    }

    /** Scales only the stored entries of a sparse vector (no centering), returning a new sparse vector. */
    private SparseVector scaleSparse(SparseVector sparseVector) {
        SparseVector result = sparseVector.clone();
        int[] indices = result.getIndices();
        double[] values = result.getValues();
        for (int k = 0; k < result.numberOfValues(); k++) {
            values[k] = standardize(values[k], indices[k], false);
        }
        return result;
    }

    /**
     * Centers and scales a sparse vector into a dense vector of length {@code means.length}.
     *
     * <p>Implicit zeros become {@code -mean/std}. This includes every feature after the last stored
     * index, and every feature of a vector with no stored entries. Stored indices outside
     * {@code [0, means.length)} are ignored. Assumes indices are ascending, as SparseVector normally
     * keeps them.
     */
    private DenseVector centerAndScaleSparse(SparseVector sparseVector) {
        int[] indices = sparseVector.getIndices();
        double[] values = sparseVector.getValues();
        DenseVector result = new DenseVector(means.length);
        double[] data = result.getData();
        int cursor = 0;
        for (int i = 0; i < data.length; i++) {
            // Skip stray entries (e.g. duplicates or negative indices) that can no longer match.
            while (cursor < indices.length && indices[cursor] < i) {
                cursor++;
            }
            double value = 0.0;
            if (cursor < indices.length && indices[cursor] == i) {
                value = values[cursor++];
            }
            data[i] = standardize(value, i, true);
        }
        return result;
    }

    /**
     * Standardizes a single feature value.
     *
     * @param value        the raw feature value
     * @param featureIndex index into {@link #means} / {@link #stdDeviations}
     * @param center       whether to subtract the feature mean first
     * @return the standardized value, or {@code 0.0} for a zero-deviation feature
     */
    private double standardize(double value, int featureIndex, boolean center) {
        double std = stdDeviations[featureIndex];
        if (std == ZERO_STD) {
            return 0.0;
        }
        double shifted = center ? value - means[featureIndex] : value;
        return shifted / std;
    }
}
