package com.alibaba.alink.operator.common.feature.pca;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.feature.HasCalculationType;
import com.alibaba.alink.params.feature.PcaPredictParams;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.class */
/**
 * Model mapper that projects input rows onto the principal components of a trained PCA model.
 *
 * <p>Features are read either from a single vector column or from a list of numerical columns.
 * For {@link HasCalculationType.CalculationType#CORR} models they are standardized with the
 * model's means and standard deviations. For any other calculation type they are passed through
 * unchanged (mean 0, std 1). The values are then handed to {@link PcaModelData#calcPrinValue}, and
 * the result is written as a dense vector to the prediction column.
 */
public class PcaModelMapper extends ModelMapper {
    private static final long serialVersionUID = -6656670267982283314L;

    /** A standard deviation whose magnitude is not above this threshold is never divided by. */
    private static final double STD_EPS = 1.0E-12d;

    private PcaModelData model;
    private int[] featureIdxs;
    private boolean isVector;
    private HasCalculationType.CalculationType pcaType;
    private double[] sourceMean;
    private double[] sourceStd;

    public PcaModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.model = null;
        this.featureIdxs = null;
        this.pcaType = null;
        this.sourceMean = null;
        this.sourceStd = null;
    }

    /**
     * Validates the feature columns against the data schema and resolves their indices.
     *
     * @param bool    {@code true} to use the single vector column {@code str}, otherwise the
     *                numerical columns {@code strArr}
     * @param strArr  numerical feature column names (used when {@code bool} is false)
     * @param str     vector column name (used when {@code bool} is true)
     * @return the column indices in the data schema; a single index in vector mode
     */
    private int[] checkGetColIndices(Boolean bool, String[] strArr, String str) {
        String[] fieldNames = getDataSchema().getFieldNames();
        if (bool.booleanValue()) {
            TableUtil.assertSelectedColExist(fieldNames, str);
            TableUtil.assertVectorCols(getDataSchema(), str);
            return new int[]{TableUtil.findColIndexWithAssertAndHint(fieldNames, str)};
        }
        TableUtil.assertSelectedColExist(fieldNames, strArr);
        TableUtil.assertNumericalCols(getDataSchema(), strArr);
        return TableUtil.findColIndicesWithAssertAndHint(fieldNames, strArr);
    }

    /**
     * Loads the PCA model and prepares the feature column indices and standardization vectors.
     *
     * @param list the serialized model rows
     */
    @Override
    public void loadModel(List<Row> list) {
        this.model = new PcaModelDataConverter().load(list);
        String vectorColName = this.model.vectorColName;
        // A vector column supplied at prediction time overrides the one recorded in the model.
        if (this.params.contains(PcaPredictParams.VECTOR_COL)) {
            vectorColName = (String) this.params.get(PcaPredictParams.VECTOR_COL);
        }
        // Assigned unconditionally so that reloading a model without a vector column resets it.
        this.isVector = vectorColName != null;
        this.featureIdxs = checkGetColIndices(this.isVector, this.model.featureColNames, vectorColName);
        this.pcaType = this.model.pcaType;

        if (HasCalculationType.CalculationType.CORR == this.pcaType) {
            this.sourceMean = this.model.means;
            this.sourceStd = this.model.stddevs;
        } else {
            // Identity transform: subtract 0 and divide by 1.
            int length = this.model.means.length;
            this.sourceMean = new double[length];
            this.sourceStd = new double[length];
            Arrays.fill(this.sourceStd, 1.0d);
        }
    }

    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        return Tuple4.of(
                tableSchema2.getFieldNames(),
                new String[]{(String) params.get(PcaPredictParams.PREDICTION_COL)},
                new TypeInformation<?>[]{AlinkTypes.DENSE_VECTOR},
                params.get(PcaPredictParams.RESERVED_COLS));
    }

    /**
     * Projects one sample onto the principal components and writes the result to output slot 0.
     *
     * @param slicedSelectedSample the selected input columns of the current row
     * @param slicedResult         receives the projected dense vector
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        double[] features = extractFeatures(slicedSelectedSample);
        double[] standardized = standardize(features);
        slicedResult.set(0, new DenseVector(this.model.calcPrinValue(standardized)));
    }

    /** Reads the raw feature values of one sample into a new array of length {@code model.nx}. */
    private double[] extractFeatures(Mapper.SlicedSelectedSample sample) {
        double[] features = new double[this.model.nx];
        if (this.isVector) {
            Vector vector = VectorUtil.getVector(sample.get(this.featureIdxs[0]));
            // A sparse vector with a negative (unspecified) size is given the model's dimension.
            if (vector instanceof SparseVector && vector.size() < 0) {
                ((SparseVector) vector).setSize(this.model.nx);
            }
            int size = vector.size();
            if (size > features.length) {
                throw new IllegalArgumentException(
                        "Vector size " + size + " exceeds the PCA model dimension " + features.length + ".");
            }
            for (int i = 0; i < size; i++) {
                features[i] = vector.get(i);
            }
        } else {
            for (int i = 0; i < this.featureIdxs.length; i++) {
                // Read via Number rather than casting to Double: assertNumericalCols also admits
                // integer and float columns, which would otherwise throw ClassCastException.
                features[i] = ((Number) sample.get(this.featureIdxs[i])).doubleValue();
            }
        }
        return features;
    }

    /**
     * Applies {@code (x - sourceMean) / sourceStd} to the features the model retains.
     *
     * <p>When {@code model.idxNonEqual} is shorter than the input, only the listed columns are
     * kept. In that case {@code sourceMean}/{@code sourceStd} are indexed by position within
     * {@code idxNonEqual}.
     *
     * <p>NOTE(review): for a near-zero std, the reduced path leaves the output at 0.0, while the
     * full-length path leaves the raw input value. This is preserved as-is.
     */
    private double[] standardize(double[] features) {
        Integer[] nonEqualIdx = this.model.idxNonEqual;
        if (nonEqualIdx.length != features.length) {
            // Presumably the columns dropped at training time were constant ones; verify in trainer.
            double[] selected = new double[nonEqualIdx.length];
            for (int i = 0; i < nonEqualIdx.length; i++) {
                if (Math.abs(this.sourceStd[i]) > STD_EPS) {
                    selected[i] = (features[nonEqualIdx[i].intValue()] - this.sourceMean[i]) / this.sourceStd[i];
                }
            }
            return selected;
        }
        for (int i = 0; i < features.length; i++) {
            if (Math.abs(this.sourceStd[i]) > STD_EPS) {
                features[i] = (features[i] - this.sourceMean[i]) / this.sourceStd[i];
            }
        }
        return features;
    }
}
