package com.alibaba.alink.operator.common.tensorflow;

import com.alibaba.alink.common.dl.plugin.DLPredictorService;
import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.dl.HasModelPath;
import com.alibaba.alink.params.tensorflow.savedmodel.BaseTFSavedModelPredictParams;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/**
 * Mapper that feeds the selected columns of each row to a TensorFlow SavedModel predictor and
 * writes the predictor's outputs to the columns described by the {@code outputSchemaStr} parameter.
 *
 * <p>The predictor is created in {@link #open()} through a {@link TFPredictorClassLoaderFactory}
 * and released in {@link #close()}. The model location comes either from the
 * {@link HasModelPath#MODEL_PATH} parameter or from {@link #setModelPath(String)}.
 */
public class BaseTFSavedModelPredictMapper extends Mapper implements Serializable {
    /** Factory handed to {@link TFPredictorClassLoaderFactory#create} in {@link #open()}. */
    private final TFPredictorClassLoaderFactory factory;
    /** Output column names, parsed from {@code outputSchemaStr}. */
    protected final String[] tfOutputCols;
    /** Java classes of the output column types, index-aligned with {@link #tfOutputCols}. */
    protected final Class<?>[] tfOutputColTypeClasses;
    private final String graphDefTag;
    private final String signatureDefKey;
    /** Input column names; defaults to every field of the input schema when not selected. */
    protected String[] tfInputCols;
    /** Created in {@link #open()}; {@code null} before that and after {@link #close()}. */
    protected DLPredictorService predictor;
    /** Input signature-def names; default to {@link #tfInputCols}. */
    private String[] inputSignatureDefs;
    /** Output signature-def names; default to {@link #tfOutputCols}. */
    private String[] outputSignatureDefs;
    /** SavedModel location; must be non-null by the time {@link #open()} runs. */
    private String modelPath;

    public BaseTFSavedModelPredictMapper(TableSchema tableSchema, Params params) {
        this(tableSchema, params, new TFPredictorClassLoaderFactory());
    }

    /**
     * Creates the mapper, reading all TensorFlow-related settings from {@code params}.
     *
     * @param tableSchema schema of the input table
     * @param params      operator parameters; {@code outputSchemaStr} is mandatory
     * @param factory     factory used in {@link #open()} to create the predictor
     * @throws AkIllegalOperatorParameterException (via {@link AkPreconditions}) if
     *                                             {@code outputSchemaStr} is not set
     */
    public BaseTFSavedModelPredictMapper(TableSchema tableSchema, Params params,
                                         TFPredictorClassLoaderFactory factory) {
        super(tableSchema, params);
        this.factory = factory;
        this.graphDefTag = (String) params.get(BaseTFSavedModelPredictParams.GRAPH_DEF_TAG);
        this.signatureDefKey = (String) params.get(BaseTFSavedModelPredictParams.SIGNATURE_DEF_KEY);

        this.tfInputCols = resolveInputCols(tableSchema, params);
        String[] inputDefs = (String[]) params.get(BaseTFSavedModelPredictParams.INPUT_SIGNATURE_DEFS);
        this.inputSignatureDefs = null == inputDefs ? this.tfInputCols : inputDefs;

        AkPreconditions.checkArgument(params.contains(BaseTFSavedModelPredictParams.OUTPUT_SCHEMA_STR),
            (ExceptionWithErrorCode) new AkIllegalOperatorParameterException("Must set outputSchemaStr."));
        TableSchema outputSchema = parseOutputSchema(params);
        this.tfOutputCols = outputSchema.getFieldNames();
        String[] outputDefs = (String[]) params.get(BaseTFSavedModelPredictParams.OUTPUT_SIGNATURE_DEFS);
        this.outputSignatureDefs = null == outputDefs ? this.tfOutputCols : outputDefs;
        this.tfOutputColTypeClasses = Arrays.stream(outputSchema.getFieldTypes())
            .map(TypeInformation::getTypeClass)
            .toArray(Class[]::new);

        if (params.contains(HasModelPath.MODEL_PATH)) {
            this.modelPath = (String) params.get(HasModelPath.MODEL_PATH);
        }
    }

    /**
     * Sets (or overrides) the SavedModel location used by {@link #open()}.
     *
     * @param str model path
     * @return this mapper, for chaining
     */
    public BaseTFSavedModelPredictMapper setModelPath(String str) {
        this.modelPath = str;
        return this;
    }

    /**
     * Builds the configuration map passed to {@link DLPredictorService#open}.
     *
     * <p>Intra-op parallelism is taken from the params when present and is {@code null} otherwise.
     * The inter-op parallelism entry is always {@code null}; whatever the predictor does with a
     * {@code null} value is up to the predictor implementation (not visible here).
     *
     * @return a new mutable map; subclasses may add entries
     */
    public Map<String, Object> getPredictorConfig() {
        Integer intraOpParallelism = this.params.contains(BaseTFSavedModelPredictParams.INTRA_OP_PARALLELISM)
            ? (Integer) this.params.get(BaseTFSavedModelPredictParams.INTRA_OP_PARALLELISM)
            : null;
        Map<String, Object> config = new HashMap<>();
        config.put("model_path", this.modelPath);
        config.put(TFSavedModelConstants.GRAPH_DEF_TAG_KEY, this.graphDefTag);
        config.put(TFSavedModelConstants.SIGNATURE_DEF_KEY_KEY, this.signatureDefKey);
        config.put(TFSavedModelConstants.INPUT_SIGNATURE_DEFS_KEY, this.inputSignatureDefs);
        config.put(TFSavedModelConstants.OUTPUT_SIGNATURE_DEFS_KEY, this.outputSignatureDefs);
        config.put("output_type_classes", this.tfOutputColTypeClasses);
        config.put(TFSavedModelConstants.INTRA_OP_PARALLELISM_KEY, intraOpParallelism);
        config.put(TFSavedModelConstants.INTER_OP_PARALLELISM_KEY, null);
        return config;
    }

    /**
     * Creates and opens the predictor.
     *
     * @throws RuntimeException (via {@link AkPreconditions}) if no model path has been set
     */
    @Override
    public void open() {
        AkPreconditions.checkArgument(this.modelPath != null, "Model path is not set.");
        this.predictor = TFPredictorClassLoaderFactory.create(this.factory);
        this.predictor.open(getPredictorConfig());
    }

    /**
     * Closes the predictor, if one was created.
     *
     * <p>If {@link #open()} was never called, or threw before the predictor was assigned, there is
     * nothing to release. Returning quietly avoids an NPE that would otherwise hide the original
     * failure. The field is cleared afterwards, so calling this method again is a no-op.
     *
     * @throws AkUnclassifiedErrorException if the predictor fails to close
     */
    @Override
    public void close() {
        if (this.predictor == null) {
            return;
        }
        try {
            this.predictor.close();
        } catch (Exception e) {
            throw new AkUnclassifiedErrorException("Failed to close predictor", e);
        } finally {
            this.predictor = null;
        }
    }

    /**
     * Declares the mapper's I/O.
     *
     * <p>Inputs are the selected columns, or all input fields if none are selected. Outputs are the
     * names and types from {@code outputSchemaStr}. The reserved columns are taken as-is from the
     * params.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema,
                                                                                          Params params) {
        TableSchema outputSchema = parseOutputSchema(params);
        return Tuple4.of(
            resolveInputCols(tableSchema, params),
            outputSchema.getFieldNames(),
            outputSchema.getFieldTypes(),
            (String[]) params.get(BaseTFSavedModelPredictParams.RESERVED_COLS));
    }

    /**
     * Runs one prediction for a single row.
     *
     * <p>All selected values are passed to the predictor in column order. The first
     * {@code slicedResult.length()} elements of the returned list are written to the output slots.
     * NOTE(review): this assumes the predictor returns outputs in the order of
     * {@link #outputSignatureDefs}, with at least that many elements — confirm against
     * {@link DLPredictorService}.
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult)
        throws Exception {
        int inputCount = slicedSelectedSample.length();
        List<Object> inputs = new ArrayList<>(inputCount);
        for (int i = 0; i < inputCount; i++) {
            inputs.add(slicedSelectedSample.get(i));
        }
        List<?> outputs = this.predictor.predict(inputs);
        for (int i = 0; i < slicedResult.length(); i++) {
            slicedResult.set(i, outputs.get(i));
        }
    }

    /**
     * Returns the selected input columns, falling back to every field of {@code tableSchema}.
     *
     * <p>Static so it is safe to call from {@link #prepareIoSchema}, which may run before this
     * class's fields are initialized.
     */
    private static String[] resolveInputCols(TableSchema tableSchema, Params params) {
        String[] selected = (String[]) params.get(BaseTFSavedModelPredictParams.SELECTED_COLS);
        return null == selected ? tableSchema.getFieldNames() : selected;
    }

    /** Parses the {@code outputSchemaStr} parameter into a {@link TableSchema}. */
    private static TableSchema parseOutputSchema(Params params) {
        return TableUtil.schemaStr2Schema((String) params.get(BaseTFSavedModelPredictParams.OUTPUT_SCHEMA_STR));
    }
}
