package com.alibaba.alink.operator.common.feature;

import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter;
import com.alibaba.alink.params.feature.CrossFeaturePredictParams;
import com.alibaba.alink.params.feature.CrossFeatureTrainParams;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/feature/CrossFeatureModelMapper.class */
public class CrossFeatureModelMapper extends ModelMapper {
    /** Positions, within the data schema, of the columns listed in the model's SELECTED_COLS. */
    int[] selectedColIndices;
    /** Field names of the input data schema. */
    String[] dataColNames;
    /** Per selected column: non-null token string -> index of that token within the column. */
    HashMap<String, Integer>[] tokenAndIndex;
    /** Per selected column: index assigned to the null token, or -1 when the model has none. */
    int[] nullIndex;
    /**
     * Mixed-radix place values: {@code carry[i]} is the product of the token counts of columns
     * {@code 0..i-1}, so each combination of per-column token indices maps to one unique position.
     */
    int[] carry;
    /** Retained for compatibility; map() keeps its per-row scratch state in locals instead. */
    int[] dataIndices;
    /** Length of the output sparse vector: the product of the token counts of all columns. */
    int svLength;

    /**
     * Creates the mapper.
     *
     * @param modelSchema schema of the model rows
     * @param dataSchema  schema of the input data; its field names are used to locate the crossed columns
     * @param params      prediction parameters (OUTPUT_COL is read in prepareIoSchema)
     */
    public CrossFeatureModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.svLength = 0;
        this.dataColNames = dataSchema.getFieldNames();
    }

    /**
     * Encodes the combination of the selected columns' tokens as a one-hot sparse vector.
     *
     * <p>Each column value is mapped to its token index (null values use the model's null token).
     * The output position is {@code sum(carry[i] * tokenIndex[i])}, set to 1.0. If any value is an
     * unseen token, or is null while the model has no null token for that column, an empty sparse
     * vector of length {@code svLength} is emitted instead.
     */
    @Override
    public void map(Mapper.SlicedSelectedSample selection, Mapper.SlicedResult result) throws Exception {
        // Accumulate in a local so map() does not mutate instance state per row.
        int crossIndex = 0;
        for (int i = 0; i < this.selectedColIndices.length; i++) {
            Object value = selection.get(this.selectedColIndices[i]);
            int tokenIndex;
            if (value == null) {
                tokenIndex = this.nullIndex[i];
                if (tokenIndex == -1) {
                    result.set(0, new SparseVector(this.svLength));
                    return;
                }
            } else {
                // Single lookup: stored values are never null, so null means "unseen token".
                Integer found = this.tokenAndIndex[i].get(value.toString());
                if (found == null) {
                    result.set(0, new SparseVector(this.svLength));
                    return;
                }
                tokenIndex = found;
            }
            crossIndex += this.carry[i] * tokenIndex;
        }
        result.set(0, new SparseVector(this.svLength, new int[] {crossIndex}, new double[] {1.0d}));
    }

    /**
     * Selects every data column, appends one SPARSE_VECTOR output column named by OUTPUT_COL, and
     * reserves all data columns in the output.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema modelSchema, TableSchema dataSchema, Params params) {
        String[] dataFields = dataSchema.getFieldNames();
        String outputCol = (String) params.get(CrossFeaturePredictParams.OUTPUT_COL);
        return Tuple4.of(
            dataFields,
            new String[] {outputCol},
            new TypeInformation[] {AlinkTypes.SPARSE_VECTOR},
            dataFields);
    }

    /**
     * Loads per-column token dictionaries from MultiStringIndexer-format model rows and precomputes
     * the mixed-radix place values and output vector length.
     *
     * @param list model rows
     * @throws IllegalArgumentException if the model has no columns, or if the product of the token
     *                                  counts does not fit in an {@code int} vector length
     */
    @Override
    public void loadModel(List<Row> list) {
        MultiStringIndexerModelData model = new MultiStringIndexerModelDataConverter().load(list);
        this.selectedColIndices = TableUtil.findColIndices(
            this.dataColNames, (String[]) model.meta.get(CrossFeatureTrainParams.SELECTED_COLS));

        int size = model.tokenNumber.size();
        if (size == 0) {
            throw new IllegalArgumentException("CrossFeature model contains no columns.");
        }

        @SuppressWarnings("unchecked") // generic array creation; every slot is filled below
        HashMap<String, Integer>[] dictionaries = new HashMap[size];
        this.tokenAndIndex = dictionaries;
        this.nullIndex = new int[size];
        Arrays.fill(this.nullIndex, -1);
        this.carry = new int[size];

        // Compute place values in long arithmetic so an oversized cross space fails loudly
        // instead of silently wrapping around to a wrong (possibly negative) int.
        long placeValue = 1L;
        for (int i = 0; i < size; i++) {
            long tokenCount = model.tokenNumber.get(i);
            this.carry[i] = (int) placeValue; // already verified to fit in an int
            placeValue = Math.multiplyExact(placeValue, tokenCount);
            if (placeValue > Integer.MAX_VALUE) {
                throw new IllegalArgumentException(
                    "CrossFeature space is too large: the product of the token counts of the "
                        + "selected columns exceeds Integer.MAX_VALUE.");
            }
            this.tokenAndIndex[i] = new HashMap<>((int) tokenCount);
        }
        this.svLength = (int) placeValue;

        for (Tuple3<Integer, String, Long> entry : model.tokenAndIndex) {
            int col = entry.f0;
            int index = entry.f2.intValue();
            if (entry.f1 == null) {
                this.nullIndex[col] = index;
            } else {
                this.tokenAndIndex[col].put(entry.f1, index);
            }
        }
        this.dataIndices = new int[size];
    }
}
