package com.alibaba.alink.operator.common.feature;

import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.params.dataproc.HasHandleInvalid;
import com.alibaba.alink.params.feature.HasEncodeWithoutWoeAndIndex;
import com.alibaba.alink.params.feature.MultiHotPredictParams;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/feature/MultiHotModelMapper.class */
/**
 * Model mapper that turns delimiter-separated string columns into multi-hot
 * {@link SparseVector}s using a loaded {@link MultiHotModelData}.
 *
 * <p>Two encodings are handled by {@link #map}:
 * <ul>
 *   <li>{@code VECTOR}: one sparse vector per selected column, written to the output slot with
 *       the same position as the column;</li>
 *   <li>{@code ASSEMBLED_VECTOR}: a single sparse vector, written to output slot 0, in which the
 *       per-column index ranges are laid out one after another.</li>
 * </ul>
 * Every emitted vector holds the value {@code 1.0} at each active index.
 *
 * <p>Layout of one column's index range, as computed by {@link #getSingleIndicesAndSize}, where
 * {@code n} is the number of tokens the model holds for the column:
 * <ul>
 *   <li>tokens whose stored index is not {@code -1} use that stored index;</li>
 *   <li>tokens whose stored index is {@code -1} use index {@code n};</li>
 *   <li>tokens missing from the model use index {@code n + (enableElse ? 1 : 0)} when the
 *       invalid-handling strategy is {@code KEEP};</li>
 *   <li>the reported size is {@code n + (enableElse ? 1 : 0) + offsetSize}.</li>
 * </ul>
 */
public class MultiHotModelMapper extends ModelMapper {
    private static final long serialVersionUID = 7431062592310976413L;

    /** Model loaded by {@link #loadModel}; {@code null} until then. */
    private MultiHotModelData model;

    /** Names of the columns to encode, or {@code null} when SELECTED_COLS is not present in the params. */
    private final String[] inputPredictColNames;

    /** Strategy applied to tokens that are missing from the model. */
    private final HasHandleInvalid.HandleInvalid handleInvalid;

    /** Requested output encoding. */
    private final HasEncodeWithoutWoeAndIndex.Encode encode;

    /** Extra slot appended to each column's range: 1 when {@code handleInvalid} is KEEP, otherwise 0. */
    private final int offsetSize;

    /** Value reported by {@code model.getEnableElse(...)}; adds one slot to each column's range when true. */
    boolean enableElse;

    /**
     * Reads the invalid-handling strategy, the encoding and the selected columns from
     * {@code params}.
     *
     * @param tableSchema  first schema, forwarded unchanged to the superclass
     * @param tableSchema2 second schema, forwarded unchanged to the superclass
     *                     (presumably the input data schema; confirm against ModelMapper)
     * @param params       prediction parameters
     */
    public MultiHotModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.enableElse = false;
        this.handleInvalid = (HasHandleInvalid.HandleInvalid) params.get(MultiHotPredictParams.HANDLE_INVALID);
        this.encode = (HasEncodeWithoutWoeAndIndex.Encode) params.get(MultiHotPredictParams.ENCODE);
        this.offsetSize = this.handleInvalid.equals(HasHandleInvalid.HandleInvalid.KEEP) ? 1 : 0;
        this.inputPredictColNames = params.contains(MultiHotPredictParams.SELECTED_COLS)
            ? (String[]) params.get(MultiHotPredictParams.SELECTED_COLS)
            : null;
    }

    /**
     * Deserializes the model rows and verifies that every selected column is present in the model.
     *
     * @param list serialized model rows
     * @throws AkIllegalArgumentException if a selected column has no entry in the loaded model
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        this.model = new MultiHotModelDataConverter().load(list);
        this.enableElse = this.model.getEnableElse(this.inputPredictColNames);
        if (this.inputPredictColNames == null) {
            return;
        }
        Set<String> trainedColumns = this.model.modelData.keySet();
        for (String column : this.inputPredictColNames) {
            if (!trainedColumns.contains(column)) {
                throw new AkIllegalArgumentException(
                    "Column '" + column + "' has not been processed in MultiHot model training.");
            }
        }
    }

    /**
     * Describes the columns this mapper reads and writes.
     *
     * <p>Only {@code params} and the schema arguments are used here, never instance fields.
     * NOTE(review): presumably this is invoked from the superclass constructor, before this
     * class's fields are assigned; confirm against ModelMapper before reading fields here.
     *
     * @param tableSchema  first schema (unused here)
     * @param tableSchema2 schema whose field names are used as reserved columns when
     *                     RESERVED_COLS is not set
     * @param params       prediction parameters
     * @return tuple of (SELECTED_COLS, OUTPUT_COLS, one SPARSE_VECTOR type per output column,
     *         reserved columns)
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        String[] reservedCols = (String[]) params.get(MultiHotPredictParams.RESERVED_COLS);
        if (reservedCols == null) {
            reservedCols = tableSchema2.getFieldNames();
        }
        String[] outputCols = (String[]) params.get(MultiHotPredictParams.OUTPUT_COLS);
        TypeInformation<?>[] outputTypes = new TypeInformation<?>[outputCols.length];
        Arrays.fill(outputTypes, AlinkTypes.SPARSE_VECTOR);
        return Tuple4.of(params.get(MultiHotPredictParams.SELECTED_COLS), outputCols, outputTypes, reservedCols);
    }

    /**
     * Encodes one sample according to the configured encoding.
     *
     * <p>An output slot is written only when at least one index is active; otherwise the slot is
     * left untouched. Encodings other than ASSEMBLED_VECTOR and VECTOR produce no output.
     *
     * @param slicedSelectedSample values of the selected columns, read as strings
     * @param slicedResult         output slots to fill
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // com.alibaba.alink.common.mapper.Mapper
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        if (this.encode.equals(HasEncodeWithoutWoeAndIndex.Encode.ASSEMBLED_VECTOR)) {
            writeOnesVector(slicedResult, 0, getIndicesAndSize(slicedSelectedSample));
        } else if (this.encode.equals(HasEncodeWithoutWoeAndIndex.Encode.VECTOR)) {
            for (int col = 0; col < slicedSelectedSample.length(); col++) {
                Tuple2<Integer, int[]> sizeAndIndices =
                    getSingleIndicesAndSize(this.inputPredictColNames[col], (String) slicedSelectedSample.get(col));
                writeOnesVector(slicedResult, col, sizeAndIndices);
            }
        }
    }

    /**
     * Computes the active indices and the range size for a single column value.
     *
     * <p>NOTE(review): the value is split with {@code String.split}, which interprets the model's
     * delimiter as a regular expression. Left as is because the split used at training time is
     * not visible from here; confirm before quoting the delimiter.
     *
     * @param str  name of the column whose token table is looked up in the model
     * @param str2 raw column value; {@code null} yields no active indices
     * @return tuple of (range size for the column, distinct active indices in no particular order)
     * @throws AkIllegalArgumentException if the model has no entry for column {@code str}
     * @throws AkIllegalModelException    if a token is missing from the model and the
     *                                    invalid-handling strategy is ERROR
     */
    public Tuple2<Integer, int[]> getSingleIndicesAndSize(String str, String str2) {
        Map<String, Tuple2<Integer, Integer>> tokenTable = this.model.modelData.get(str);
        if (tokenTable == null) {
            // Fail with a descriptive error instead of a NullPointerException further down.
            throw new AkIllegalArgumentException(
                "Column '" + str + "' has not been processed in MultiHot model training.");
        }
        int elseSlots = this.enableElse ? 1 : 0;
        Set<Integer> active = new HashSet<>();
        if (str2 != null) {
            for (String rawToken : str2.split(this.model.delimiter)) {
                Tuple2<Integer, Integer> entry = tokenTable.get(rawToken.trim());
                if (entry == null) {
                    switch (this.handleInvalid) {
                        case KEEP:
                            // Unknown tokens share the slot that follows the optional "else" slot.
                            active.add(tokenTable.size() + elseSlots);
                            break;
                        case ERROR:
                            throw new AkIllegalModelException("multi hot encoder err, key is not exist.");
                        default:
                            // Any other strategy silently drops the unknown token.
                            break;
                    }
                } else if (entry.f0.intValue() != -1) {
                    active.add(entry.f0);
                } else {
                    // A stored index of -1 maps to the slot right after the regular tokens.
                    active.add(tokenTable.size());
                }
            }
        }
        return Tuple2.of(Integer.valueOf(tokenTable.size() + elseSlots + this.offsetSize), toIntArray(active));
    }

    /**
     * Computes the active indices for all selected columns laid out in one combined range.
     *
     * <p>Each column's indices are shifted by the sum of the range sizes of the columns before it.
     *
     * @param slicedSelectedSample values of the selected columns, read as strings
     * @return tuple of (total size over all columns, distinct shifted indices in no particular order)
     */
    public Tuple2<Integer, int[]> getIndicesAndSize(Mapper.SlicedSelectedSample slicedSelectedSample) {
        Set<Integer> active = new HashSet<>();
        int base = 0;
        for (int col = 0; col < slicedSelectedSample.length(); col++) {
            Tuple2<Integer, int[]> single =
                getSingleIndicesAndSize(this.inputPredictColNames[col], (String) slicedSelectedSample.get(col));
            for (int index : single.f1) {
                active.add(base + index);
            }
            base += single.f0.intValue();
        }
        return Tuple2.of(Integer.valueOf(base), toIntArray(active));
    }

    /** Writes a sparse vector of ones into slot {@code position}, unless there are no active indices. */
    private static void writeOnesVector(Mapper.SlicedResult slicedResult, int position, Tuple2<Integer, int[]> sizeAndIndices) {
        int[] indices = sizeAndIndices.f1;
        if (indices.length == 0) {
            return;
        }
        int size = sizeAndIndices.f0.intValue();
        double[] ones = new double[indices.length];
        Arrays.fill(ones, 1.0d);
        slicedResult.set(position, new SparseVector(size, indices, ones));
    }

    /** Copies the set's elements into a primitive array, in the set's iteration order. */
    private static int[] toIntArray(Set<Integer> values) {
        int[] result = new int[values.size()];
        int next = 0;
        for (Integer value : values) {
            result[next++] = value.intValue();
        }
        return result;
    }
}
