package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.dataproc.LookupParams;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/LookupModelMapper.class */
/**
 * Model mapper that looks up values in an in-memory key/value table built from model rows.
 *
 * <p>Each model row contributes one entry: the values of the map-key columns form the key and
 * the values of the map-value columns form the looked-up result. For every input sample, the
 * selected columns are used as the lookup key and the matching values are written to the output
 * columns; when no entry matches, the outputs are set to {@code null}.
 *
 * <p>NOTE: {@link #map} reuses a single mutable key buffer per instance, so one instance must not
 * be used to map samples from several threads at the same time.
 */
public class LookupModelMapper extends ModelMapper {
    /** Indices of the {@code LookupParams.SELECTED_COLS} columns, resolved against the data schema. */
    protected final int[] selectedColIndices;
    /** Indices of the key columns inside a model row. */
    private final int[] mapKeyColIndices;
    /** Indices of the value columns inside a model row. */
    private final int[] mapValueColIndices;
    /** Lookup table: key column values -> value column values. */
    private HashMap<List<Object>, Object[]> mapModel;
    private String[] outputColNames;
    /** Number of key columns. */
    private final int mk;
    /** Reusable key buffer for {@link #map}; avoids allocating a list per sample. */
    private final List<Object> currentKey;

    /**
     * Resolves the key/value columns against the model schema and the selected columns against
     * the data schema.
     *
     * @param tableSchema  schema of the model rows (map key/value columns are resolved here)
     * @param tableSchema2 schema of the input data (selected columns are resolved here)
     * @param params       lookup parameters
     * @throws AkIllegalOperatorParameterException if the model has more than two columns and the
     *         key or value columns are not configured, or if key columns are configured
     *         explicitly and their count differs from the number of selected columns
     * @throws AkIllegalDataException if a selected column and its corresponding key column have
     *         different types
     */
    public LookupModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        String[] mapKeyColNames = (String[]) params.get(LookupParams.MAP_KEY_COLS);
        this.mapKeyColIndices = mapKeyColNames != null
            ? TableUtil.findColIndicesWithAssertAndHint(tableSchema, mapKeyColNames)
            : new int[]{0};
        String[] mapValueColNames = (String[]) params.get(LookupParams.MAP_VALUE_COLS);
        this.mapValueColIndices = mapValueColNames != null
            ? TableUtil.findColIndicesWithAssertAndHint(tableSchema, mapValueColNames)
            : new int[]{1};
        // The positional defaults (column 0 -> key, column 1 -> value) are only unambiguous for
        // a two-column model.
        if (tableSchema.getFieldNames().length > 2 && (mapKeyColNames == null || mapValueColNames == null)) {
            throw new AkIllegalOperatorParameterException("LookUpMapper err : mapKeyCols and mapValueCols should set in parameters.");
        }

        this.mk = this.mapKeyColIndices.length;
        this.currentKey = new ArrayList<>(this.mk);
        for (int i = 0; i < this.mk; i++) {
            this.currentKey.add(null);
        }

        String[] selectedColNames = (String[]) params.get(LookupParams.SELECTED_COLS);
        this.selectedColIndices = TableUtil.findColIndicesWithAssertAndHint(tableSchema2, selectedColNames);

        // Types are only cross-checked when the mapping is configured explicitly.
        if (mapKeyColNames != null && mapValueColNames != null) {
            // Selected columns are paired with key columns by position, so the counts must agree;
            // otherwise the pairwise check below would run past the end of one of the arrays.
            if (selectedColNames.length != mapKeyColNames.length) {
                throw new AkIllegalOperatorParameterException(
                    "LookUpMapper err : selectedCols and mapKeyCols should have the same number of columns, but got "
                        + selectedColNames.length + " and " + mapKeyColNames.length + ".");
            }
            for (int i = 0; i < selectedColNames.length; i++) {
                TypeInformation<?> selectedType = TableUtil.findColTypeWithAssertAndHint(tableSchema2, selectedColNames[i]);
                TypeInformation<?> keyType = TableUtil.findColTypeWithAssertAndHint(tableSchema, mapKeyColNames[i]);
                // Compare by value: equal types are not guaranteed to be the same instance.
                boolean sameType = selectedType == keyType
                    || (selectedType != null && selectedType.equals(keyType));
                if (!sameType) {
                    throw new AkIllegalDataException("Data types are not match. selected column type is " + selectedType + " , and the map key column type is " + keyType);
                }
            }
        }

        this.outputColNames = resolveOutputColNames(tableSchema, params);
    }

    /**
     * Builds the lookup table from scratch out of the given model rows.
     *
     * @param list model rows; later rows overwrite earlier rows that have the same key
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        HashMap<List<Object>, Object[]> table = new HashMap<>(list.size());
        putRows(table, list);
        this.mapModel = table;
    }

    /**
     * Replaces the lookup table with one built from the given rows and returns this mapper.
     *
     * <p>When {@code LookupParams.MODEL_STREAM_UPDATE_METHOD} is {@code INCREMENT}, the entries of
     * the current table are kept and the new rows are merged on top of them (new rows win on key
     * collisions). For any other update method only the new rows are kept.
     *
     * @param list rows of the new model
     * @return this instance, now backed by the new lookup table
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public ModelMapper createNew(List<Row> list) {
        // Tolerate being called before loadModel(): treat a missing model as empty.
        HashMap<List<Object>, Object[]> previous = this.mapModel;
        HashMap<List<Object>, Object[]> updated = new HashMap<>(previous == null ? 0 : previous.size());
        LookupParams.ModelStreamUpdateMethod method =
            (LookupParams.ModelStreamUpdateMethod) this.params.get(LookupParams.MODEL_STREAM_UPDATE_METHOD);
        if (method == LookupParams.ModelStreamUpdateMethod.INCREMENT && previous != null) {
            updated.putAll(previous);
        }
        putRows(updated, list);
        // The table is built completely before it is published, so map() never sees a
        // half-filled table through this reference.
        this.mapModel = updated;
        return this;
    }

    /**
     * Describes the mapper's input and output columns.
     *
     * @param tableSchema  schema of the model rows
     * @param tableSchema2 schema of the input data (not read here)
     * @param params       lookup parameters
     * @return tuple of (selected column names, output column names, output column types,
     *         reserved column names); the output types are those of the map-value columns, or of
     *         the model's second column when no value columns are configured
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        this.outputColNames = resolveOutputColNames(tableSchema, params);
        String[] mapValueColNames = (String[]) params.get(LookupParams.MAP_VALUE_COLS);
        String[] selectedColNames = (String[]) params.get(LookupParams.SELECTED_COLS);
        String[] typedCols = mapValueColNames != null
            ? mapValueColNames
            : new String[]{tableSchema.getFieldNames()[1]};
        TypeInformation<?>[] outputTypes = TableUtil.findColTypesWithAssertAndHint(tableSchema, typedCols);
        return Tuple4.of(selectedColNames, this.outputColNames, outputTypes, params.get(LookupParams.RESERVED_COLS));
    }

    /**
     * Looks up the values for one sample and writes them to the result.
     *
     * @param slicedSelectedSample values of the selected columns; the first {@code mk} of them
     *        form the lookup key
     * @param slicedResult receives one value per map-value column, or {@code null} in every
     *        position when the key is not present in the model
     */
    @Override // com.alibaba.alink.common.mapper.Mapper
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        for (int i = 0; i < this.mk; i++) {
            this.currentKey.set(i, slicedSelectedSample.get(i));
        }
        Object[] found = this.mapModel.get(this.currentKey);
        int outputCount = found != null ? found.length : this.mapValueColIndices.length;
        for (int i = 0; i < outputCount; i++) {
            slicedResult.set(i, found != null ? found[i] : null);
        }
    }

    /** Adds one lookup entry per row to {@code target}; later rows overwrite equal keys. */
    private void putRows(HashMap<List<Object>, Object[]> target, List<Row> rows) {
        int valueCount = this.mapValueColIndices.length;
        for (Row row : rows) {
            Object[] values = new Object[valueCount];
            for (int i = 0; i < valueCount; i++) {
                values[i] = row.getField(this.mapValueColIndices[i]);
            }
            // A fresh list per entry: the map keeps a reference to its keys.
            List<Object> key = new ArrayList<>(this.mk);
            for (int i = 0; i < this.mk; i++) {
                key.add(row.getField(this.mapKeyColIndices[i]));
            }
            target.put(key, values);
        }
    }

    /**
     * Picks the output column names: the configured output columns, else the configured map-value
     * columns, else the name of the model's second column (the same column whose type is used as
     * the default output type).
     *
     * <p>Static because it is reached from {@link #prepareIoSchema} and must not depend on
     * instance state.
     */
    private static String[] resolveOutputColNames(TableSchema modelSchema, Params params) {
        String[] names = (String[]) params.get(LookupParams.OUTPUT_COLS);
        if (names == null) {
            names = (String[]) params.get(LookupParams.MAP_VALUE_COLS);
        }
        if (names == null) {
            names = new String[]{modelSchema.getFieldNames()[1]};
        }
        return names;
    }
}
