package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.dataproc.HasHandleInvalid;
import com.alibaba.alink.params.dataproc.MultiStringIndexerPredictParams;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/**
 * Model mapper that replaces the value of each selected column with the {@code Long} index
 * that a trained multi-column string indexer model assigned to that value's string form.
 *
 * <p>Values not found in the model are handled according to {@code HANDLE_INVALID}:
 * {@code KEEP} emits the per-column value stored in the model's {@code tokenNumber}
 * (presumably the count of distinct tokens, i.e. one past the last assigned index — confirm
 * against the trainer), {@code SKIP} emits {@code null}, and {@code ERROR} throws an
 * {@link AkIllegalDataException}.
 */
public class MultiStringIndexerModelMapper extends ModelMapper {
    private static final long serialVersionUID = 7434426152864663314L;

    /** Output position {@code i} -> (token -> index) for the model column backing selected column {@code i}. */
    private transient Map<Integer, Map<String, Long>> indexMapper;

    /** Output position {@code i} -> index emitted for unseen tokens under the {@code KEEP} strategy. */
    private transient Map<Integer, Long> defaultIndex;

    /** Columns of the input data to be indexed, in output order. */
    private final String[] selectedColNames;

    /** What to do when a value is not present in the model. */
    private final HasHandleInvalid.HandleInvalid handleInvalidStrategy;

    public MultiStringIndexerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        this.handleInvalidStrategy = (HasHandleInvalid.HandleInvalid) this.params.get(MultiStringIndexerPredictParams.HANDLE_INVALID);
        this.selectedColNames = (String[]) this.params.get(MultiStringIndexerPredictParams.SELECTED_COLS);
    }

    /**
     * Builds the per-column token lookup tables from the serialized model.
     *
     * <p>Each selected column is resolved against the column names recorded in the model
     * metadata (asserting that it exists there). The token list is then scanned once, and only
     * entries belonging to a selected model column are kept.
     *
     * @param modelRows serialized model rows, decoded by {@link MultiStringIndexerModelDataConverter}
     */
    @Override
    public void loadModel(List<Row> modelRows) {
        MultiStringIndexerModelData modelData = new MultiStringIndexerModelDataConverter().load(modelRows);
        String[] trainedColNames = (String[]) modelData.meta.get(HasSelectedCols.SELECTED_COLS);
        int[] modelColIndices = TableUtil.findColIndicesWithAssert(trainedColNames, this.selectedColNames);

        // One (token -> index) table per model column that is actually selected.
        Map<Integer, Map<String, Long>> tokensByModelCol = new HashMap<>();
        for (int modelCol : modelColIndices) {
            tokensByModelCol.putIfAbsent(modelCol, new HashMap<>());
        }

        // Single pass over all tokens. The previous implementation rescanned the whole list
        // once per selected column (O(columns * tokens)). Later entries for the same token
        // still override earlier ones, exactly as before.
        for (Tuple3<Integer, String, Long> entry : modelData.tokenAndIndex) {
            Map<String, Long> tokens = tokensByModelCol.get(entry.f0);
            if (tokens != null) {
                tokens.put(entry.f1, entry.f2);
            }
        }

        this.indexMapper = new HashMap<>();
        this.defaultIndex = new HashMap<>();
        for (int i = 0; i < this.selectedColNames.length; i++) {
            int modelCol = modelColIndices[i];
            // If a column is selected twice, both positions share one table; it is only read afterwards.
            this.indexMapper.put(i, tokensByModelCol.get(modelCol));
            this.defaultIndex.put(i, modelData.tokenNumber.get(modelCol));
        }
    }

    /**
     * Declares the mapper's I/O schema.
     *
     * <p>Input columns are {@code SELECTED_COLS}. Output columns are {@code OUTPUT_COLS}, or the
     * selected columns themselves (in-place replacement) when {@code OUTPUT_COLS} is unset.
     * Every output column is typed {@code LONG}.
     *
     * @throws AkIllegalOperatorParameterException if output and selected column counts differ
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        String[] selectedCols = (String[]) params.get(MultiStringIndexerPredictParams.SELECTED_COLS);
        String[] outputCols = (String[]) params.get(MultiStringIndexerPredictParams.OUTPUT_COLS);
        if (outputCols == null) {
            outputCols = selectedCols;
        }
        AkPreconditions.checkArgument(outputCols.length == selectedCols.length, (ExceptionWithErrorCode) new AkIllegalOperatorParameterException("OutputCol length must be equal to selectedCol length!"));
        String[] reservedCols = (String[]) params.get(MultiStringIndexerPredictParams.RESERVED_COLS);

        TypeInformation<?>[] outputTypes = new TypeInformation<?>[selectedCols.length];
        Arrays.fill(outputTypes, Types.LONG);
        return Tuple4.of(selectedCols, outputCols, outputTypes, reservedCols);
    }

    /**
     * Replaces each selected value with its model index.
     *
     * <p>Each non-null value is converted with {@link String#valueOf(Object)} before lookup. A
     * null value is looked up as a null key. Values missing from the model are handled per
     * {@code HANDLE_INVALID}; see the class documentation.
     *
     * <p>Note: the decompiler reported this method as originally {@code protected}. It is kept
     * {@code public} so that existing callers continue to compile.
     *
     * @throws AkIllegalDataException if a value is unseen and the strategy is {@code ERROR}
     * @throws AkUnsupportedOperationException if the strategy is not one of KEEP/SKIP/ERROR
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        for (int i = 0; i < this.selectedColNames.length; i++) {
            Object raw = slicedSelectedSample.get(i);
            String token = raw == null ? null : String.valueOf(raw);
            Long index = this.indexMapper.get(i).get(token);
            if (index != null) {
                slicedResult.set(i, index);
                continue;
            }
            switch (this.handleInvalidStrategy) {
                case KEEP:
                    slicedResult.set(i, this.defaultIndex.get(i));
                    break;
                case SKIP:
                    slicedResult.set(i, null);
                    break;
                case ERROR:
                    throw new AkIllegalDataException("Unseen token: " + token);
                default:
                    throw new AkUnsupportedOperationException("Invalid handle invalid strategy.");
            }
        }
    }
}
