package com.alibaba.alink.operator.common.feature.AutoCross;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.feature.AutoCrossPredictParams;
import com.alibaba.alink.params.feature.featuregenerator.HasAppendOriginalData;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import java.io.Serializable;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.function.TriFunction;

/* loaded from: input_file:com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossAlgoModelMapper.class */
/**
 * Model mapper that applies a trained AutoCross feature set to the model's vector column and writes
 * the expanded features as a single vector into the configured output column.
 *
 * <p>The expansion itself is delegated to {@link OneHotOperator#oneHotData}. The output layout
 * depends on {@link AutoCrossPredictParams#OUTPUT_FORMAT} and
 * {@link HasAppendOriginalData#APPEND_ORIGINAL_DATA}:
 * <ul>
 *   <li>Sparse + append original: the model's numerical columns occupy slots
 *       {@code [0, numNumerical)}, followed by every entry of the expanded vector shifted by
 *       {@code numNumerical}.</li>
 *   <li>Sparse without original: only the expanded entries that come after the input vector's own
 *       entries, re-based by the input vector size (presumably the generated cross features —
 *       NOTE(review): confirm ordering guarantee of {@code OneHotOperator#oneHotData}).</li>
 *   <li>Dense + append original: the numerical column values followed by every expanded index,
 *       each stored as a {@code double}.</li>
 *   <li>Dense without original: the expanded indices that come after the input vector's own
 *       entries, re-based by the input vector size and stored as {@code double}s.</li>
 * </ul>
 */
public class AutoCrossAlgoModelMapper extends ModelMapper {
    private static final long serialVersionUID = -4500389710522943248L;

    /**
     * Input data column names. Assigned inside {@link #prepareIoSchema} (presumably invoked from the
     * ModelMapper constructor — TODO confirm), so this field must not get a field initializer.
     */
    private String[] dataCols;
    /** Positions in {@link #dataCols} of the model's numerical columns; set in loadModel. */
    private int[] numericalIndices;
    /**
     * Position of the vector column in {@link #dataCols}. The value -2 marks "not yet resolved"; only
     * loadModel assigns it in this class, and only while it still holds -2.
     */
    private int vecIndex;
    /** Expansion operator built from the loaded feature set. */
    private OneHotOperator operator;
    /**
     * Running end offsets of each feature block, built only for Dense output. It is forwarded to the
     * map functions as {@code f2}, but none of the map functions in this class read it.
     */
    private int[] cumsumIndex;
    private final AutoCrossPredictParams.OutputFormat outputFormat;
    /** Row transformer selected in loadModel from {@link #outputFormat} and {@link #appendOriginalVec}. */
    TriFunction<Tuple4<Mapper.SlicedSelectedSample, Integer, int[], int[]>, OneHotOperator, Mapper.SlicedResult, Row> mapOperator;
    /** Whether the original numerical/one-hot data is kept in front of the generated features. */
    boolean appendOriginalVec;

    /**
     * JSON payload of the model row whose first field is {@code 0L}. Fields are public so the JSON
     * converter can populate them.
     */
    public static class FeatureSet implements Serializable {
        private static final long serialVersionUID = 3402906686076385472L;
        /** Passed to the OneHotOperator constructor; presumably the number of raw categorical features. */
        public int numRawFeatures;
        /** Names of the numerical columns copied into the output when original data is appended. */
        public String[] numericalCols;
        /** Name of the input vector column. */
        public String vecColName;
        /** Each entry lists indices into {@link #indexSize} that are crossed together (assumed — confirm). */
        public List<int[]> crossFeatureSet;
        /** Not read by this mapper; presumably the selection score of each cross feature. */
        public List<Double> scores;
        /** Per-raw-feature block size (presumably one-hot cardinality — confirm). */
        public int[] indexSize;
        /** Not read by this mapper. */
        public boolean hasDiscrete;
    }

    /**
     * Creates the mapper.
     *
     * @param modelSchema schema of the model rows
     * @param dataSchema  schema of the input data rows
     * @param params      predict parameters; OUTPUT_FORMAT and APPEND_ORIGINAL_DATA are read here
     */
    public AutoCrossAlgoModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.vecIndex = -2;
        this.outputFormat = (AutoCrossPredictParams.OutputFormat) params.get(AutoCrossPredictParams.OUTPUT_FORMAT);
        this.appendOriginalVec = (Boolean) params.get(HasAppendOriginalData.APPEND_ORIGINAL_DATA);
    }

    /**
     * Declares all data columns as selected input and one VECTOR output column named OUTPUT_COL,
     * and keeps RESERVED_COLS. This also records the data column names for later index lookups.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema modelSchema, TableSchema dataSchema, Params params) {
        this.dataCols = dataSchema.getFieldNames();
        String outputCol = (String) params.get(AutoCrossPredictParams.OUTPUT_COL);
        return Tuple4.of(
            this.dataCols,
            new String[] {outputCol},
            new TypeInformation<?>[] {AlinkTypes.VECTOR},
            params.get(HasReservedColsDefaultAsNull.RESERVED_COLS));
    }

    /**
     * Parses the feature set stored in the model row keyed by {@code 0L}. It then resolves column
     * positions, builds the expansion operator, and selects the row transformer for the configured
     * output format.
     *
     * @param list model rows; field 0 is the row key and field 1 is the JSON payload
     * @throws IllegalArgumentException if no row has key {@code 0L}, or the output format is neither
     *                                  Sparse nor Dense
     */
    @Override
    public void loadModel(List<Row> list) {
        String json = list.stream()
            // Null-safe comparison: rows with a null key are simply skipped.
            .filter(row -> Long.valueOf(0L).equals(row.getField(0)))
            .map(row -> (String) row.getField(1))
            .findFirst()
            .orElseThrow(() -> new IllegalArgumentException(
                "AutoCross model does not contain a feature-set row (field 0 == 0L)."));
        FeatureSet featureSet = (FeatureSet) JsonConverter.fromJson(json, FeatureSet.class);

        this.numericalIndices = TableUtil.findColIndices(this.dataCols, featureSet.numericalCols);
        if (this.vecIndex == -2) {
            this.vecIndex = TableUtil.findColIndex(this.dataCols, featureSet.vecColName);
        }
        this.operator = new OneHotOperator(featureSet.numRawFeatures, featureSet.crossFeatureSet, featureSet.indexSize);

        if (this.outputFormat == AutoCrossPredictParams.OutputFormat.Sparse) {
            if (this.appendOriginalVec) {
                this.mapOperator = AutoCrossAlgoModelMapper::mapSparse;
            } else {
                this.mapOperator = AutoCrossAlgoModelMapper::mapSparseWithoutOriginal;
            }
        } else if (this.outputFormat == AutoCrossPredictParams.OutputFormat.Dense) {
            this.cumsumIndex = buildCumsumIndex(featureSet.indexSize, featureSet.crossFeatureSet);
            if (this.appendOriginalVec) {
                this.mapOperator = AutoCrossAlgoModelMapper::mapDense;
            } else {
                this.mapOperator = AutoCrossAlgoModelMapper::mapDenseWithoutOriginal;
            }
        } else {
            // Without this, mapOperator would stay null and every map() call would NPE.
            throw new IllegalArgumentException("Unsupported AutoCross output format: " + this.outputFormat);
        }
    }

    /**
     * Builds running end offsets over the raw feature blocks followed by the cross feature blocks.
     * The last block's end is omitted. A raw block's size is {@code indexSize[k]}. A cross block's
     * size is the product of {@code indexSize[member]} over its members.
     */
    private static int[] buildCumsumIndex(int[] indexSize, List<int[]> crossFeatureSet) {
        int numBlocks = indexSize.length + crossFeatureSet.size();
        if (numBlocks <= 1) {
            return new int[0];
        }
        int[] cumsum = new int[numBlocks - 1];
        int running = 0;
        // Stop one block early: the array only holds numBlocks - 1 offsets.
        for (int block = 0; block < numBlocks - 1; block++) {
            int blockSize;
            if (block < indexSize.length) {
                blockSize = indexSize[block];
            } else {
                blockSize = 1;
                for (int member : crossFeatureSet.get(block - indexSize.length)) {
                    blockSize *= indexSize[member];
                }
            }
            running += blockSize;
            cumsum[block] = running;
        }
        return cumsum;
    }

    /** Delegates one row to the transformer chosen in loadModel. */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        this.mapOperator.apply(
            Tuple4.of(slicedSelectedSample, Integer.valueOf(this.vecIndex), this.cumsumIndex, this.numericalIndices),
            this.operator,
            slicedResult);
    }

    /**
     * Sparse output without original data. It keeps only the expanded entries after the input
     * vector's own entries, with indices shifted down by the input vector size.
     *
     * @return always {@code null}; the result is written into {@code slicedResult}
     */
    private static Row mapSparseWithoutOriginal(Tuple4<Mapper.SlicedSelectedSample, Integer, int[], int[]> tuple4, OneHotOperator oneHotOperator, Mapper.SlicedResult slicedResult) {
        SparseVector input = VectorUtil.getSparseVector(tuple4.f0.get(tuple4.f1));
        // Capture the input dimensions before expansion.
        int inputSize = input.size();
        int inputNnz = input.numberOfValues();
        SparseVector expanded = oneHotOperator.oneHotData(input);
        int[] expandedIndices = expanded.getIndices();
        double[] expandedValues = expanded.getValues();

        int tailNnz = expanded.numberOfValues() - inputNnz;
        int[] indices = new int[tailNnz];
        double[] values = new double[tailNnz];
        for (int i = 0; i < tailNnz; i++) {
            indices[i] = expandedIndices[inputNnz + i] - inputSize;
            values[i] = expandedValues[inputNnz + i];
        }
        slicedResult.set(0, new SparseVector(expanded.size() - inputSize, indices, values));
        return null;
    }

    /**
     * Sparse output with original data. Numerical column values take slots
     * {@code [0, numNumerical)}, followed by all expanded entries shifted by {@code numNumerical}.
     *
     * @return always {@code null}; the result is written into {@code slicedResult}
     */
    private static Row mapSparse(Tuple4<Mapper.SlicedSelectedSample, Integer, int[], int[]> tuple4, OneHotOperator oneHotOperator, Mapper.SlicedResult slicedResult) {
        Mapper.SlicedSelectedSample sample = tuple4.f0;
        int[] numericalIndices = tuple4.f3;
        SparseVector expanded = oneHotOperator.oneHotData(VectorUtil.getSparseVector(sample.get(tuple4.f1)));
        int[] expandedIndices = expanded.getIndices();
        double[] expandedValues = expanded.getValues();

        int numNumerical = numericalIndices.length;
        int[] indices = new int[numNumerical + expandedIndices.length];
        double[] values = new double[numNumerical + expandedIndices.length];
        for (int i = 0; i < numNumerical; i++) {
            indices[i] = i;
            values[i] = ((Number) sample.get(numericalIndices[i])).doubleValue();
        }
        // Shift into a fresh array instead of mutating the vector returned by the operator.
        for (int i = 0; i < expandedIndices.length; i++) {
            indices[numNumerical + i] = expandedIndices[i] + numNumerical;
        }
        System.arraycopy(expandedValues, 0, values, numNumerical, expandedValues.length);
        slicedResult.set(0, new SparseVector(expanded.size() + numNumerical, indices, values));
        return null;
    }

    /**
     * Dense output without original data. It emits the expanded indices after the input vector's own
     * entries, shifted down by the input vector size and stored as doubles.
     *
     * @return always {@code null}; the result is written into {@code slicedResult}
     */
    private static Row mapDenseWithoutOriginal(Tuple4<Mapper.SlicedSelectedSample, Integer, int[], int[]> tuple4, OneHotOperator oneHotOperator, Mapper.SlicedResult slicedResult) {
        SparseVector input = VectorUtil.getSparseVector(tuple4.f0.get(tuple4.f1));
        int inputSize = input.size();
        int inputNnz = input.numberOfValues();
        SparseVector expanded = oneHotOperator.oneHotData(input);
        // Read the EXPANDED indices. The input vector only has inputNnz entries, so indexing it past
        // inputNnz always overflowed.
        int[] expandedIndices = expanded.getIndices();

        int tailNnz = expanded.numberOfValues() - inputNnz;
        double[] data = new double[tailNnz];
        for (int i = 0; i < tailNnz; i++) {
            data[i] = expandedIndices[inputNnz + i] - inputSize;
        }
        slicedResult.set(0, new DenseVector(data));
        return null;
    }

    /**
     * Dense output with original data. It emits the numerical column values followed by every
     * expanded index, stored as doubles.
     *
     * @return always {@code null}; the result is written into {@code slicedResult}
     */
    private static Row mapDense(Tuple4<Mapper.SlicedSelectedSample, Integer, int[], int[]> tuple4, OneHotOperator oneHotOperator, Mapper.SlicedResult slicedResult) {
        Mapper.SlicedSelectedSample sample = tuple4.f0;
        int[] numericalIndices = tuple4.f3;
        int[] expandedIndices = oneHotOperator.oneHotData(VectorUtil.getSparseVector(sample.get(tuple4.f1))).getIndices();

        int numNumerical = numericalIndices.length;
        double[] data = new double[numNumerical + expandedIndices.length];
        for (int i = 0; i < numNumerical; i++) {
            data[i] = ((Number) sample.get(numericalIndices[i])).doubleValue();
        }
        for (int i = 0; i < expandedIndices.length; i++) {
            data[numNumerical + i] = expandedIndices[i];
        }
        slicedResult.set(0, new DenseVector(data));
        return null;
    }
}
