package com.alibaba.alink.operator.common.dataproc.vector;

import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.SISOModelMapper;
import com.alibaba.alink.common.model.RichModelDataConverter;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.params.dataproc.HasStrategy;
import com.alibaba.alink.params.dataproc.vector.VectorSrtPredictorParams;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/vector/VectorImputerModelMapper.class */
public class VectorImputerModelMapper extends SISOModelMapper {
    private static final long serialVersionUID = 961247156517825658L;

    /** Per-dimension fill values; used when {@link #useOneDefaultValue} is false. */
    private double[] defaultValueArray;

    /** Single fill value applied to every dimension; used when {@link #useOneDefaultValue} is true. */
    private double defaultValue;

    /** Whether {@link #defaultValue} (true) or {@link #defaultValueArray} (false) supplies fill values. */
    private boolean useOneDefaultValue;

    /**
     * Creates the mapper.
     *
     * <p>The model's first selected column name is written into {@code params} as
     * {@code VectorSrtPredictorParams.SELECTED_COL} before the super constructor runs. This mutates
     * the caller's {@code params} object.
     *
     * @param tableSchema  schema of the model data; the selected column name is read from it
     * @param tableSchema2 schema of the input data
     * @param params       mapper parameters
     */
    public VectorImputerModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params.set(
            VectorSrtPredictorParams.SELECTED_COL,
            RichModelDataConverter.extractSelectedColNames(tableSchema)[0]));
        this.useOneDefaultValue = false;
    }

    /** The prediction column always holds a vector. */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected TypeInformation initPredResultColType() {
        return AlinkTypes.VECTOR;
    }

    /**
     * Replaces NaN entries of the input vector with the model's fill values.
     *
     * @param obj the raw value of the selected column, converted via {@link VectorUtil#getVector}
     * @return the imputed vector, or {@code null} if the input converts to {@code null}
     */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    public Object predictResult(Object obj) {
        Vector vector = VectorUtil.getVector(obj);
        if (null == vector) {
            return null;
        }
        return vector instanceof DenseVector ? predict((DenseVector) vector) : predict((SparseVector) vector);
    }

    /**
     * Loads the fill values from the model rows.
     *
     * <p>If the model carries a non-empty per-dimension array, that array is used. Otherwise the
     * single scalar value from the model is used, and it must be present.
     *
     * @param list model rows, decoded by {@link VectorImputerModelDataConverter}
     * @throws AkIllegalOperatorParameterException if neither an array nor a scalar fill value exists
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        Tuple3<HasStrategy.Strategy, double[], Double> load = new VectorImputerModelDataConverter().load(list);
        this.defaultValueArray = load.f1;
        if (this.defaultValueArray == null || this.defaultValueArray.length == 0) {
            if (load.f2 == null) {
                throw new AkIllegalOperatorParameterException("In VALUE strategy, the filling value is necessary.");
            }
            this.defaultValue = load.f2;
            this.useOneDefaultValue = true;
        } else {
            // Reset explicitly: a previous loadModel call may have left this flag set to true.
            this.useOneDefaultValue = false;
        }
    }

    /**
     * Returns the fill value for dimension {@code index}.
     * In per-dimension mode, an index past the end of the array throws ArrayIndexOutOfBoundsException.
     */
    private double fillValue(int index) {
        return this.useOneDefaultValue ? this.defaultValue : this.defaultValueArray[index];
    }

    /**
     * Replaces NaN entries of a dense vector in place.
     * Its backing data array is modified, and the same instance is returned.
     */
    private DenseVector predict(DenseVector denseVector) {
        double[] data = denseVector.getData();
        for (int i = 0; i < data.length; i++) {
            if (Double.isNaN(data[i])) {
                data[i] = fillValue(i);
            }
        }
        return denseVector;
    }

    /**
     * Replaces NaN entries among the stored values of a sparse vector in place.
     * Implicit (unstored) entries are left untouched. The fill value for each stored entry is
     * looked up by that entry's dimension index, and the same instance is returned.
     */
    private SparseVector predict(SparseVector sparseVector) {
        double[] values = sparseVector.getValues();
        int[] indices = sparseVector.getIndices();
        int numberOfValues = sparseVector.numberOfValues();
        for (int i = 0; i < numberOfValues; i++) {
            if (Double.isNaN(values[i])) {
                values[i] = fillValue(indices[i]);
            }
        }
        return sparseVector;
    }
}
