package com.alibaba.alink.operator.common.onnx;

import com.alibaba.alink.common.dl.plugin.DLPredictServiceMapper;
import com.alibaba.alink.common.dl.plugin.DLPredictorService;
import com.alibaba.alink.common.dl.plugin.OnnxPredictorClassLoaderFactory;
import com.alibaba.alink.params.dl.HasIntraOpParallelism;
import com.alibaba.alink.params.onnx.OnnxModelPredictParams;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/onnx/OnnxModelPredictMapper.class */
/**
 * Mapper that runs ONNX model prediction through the DL predictor-service plugin.
 *
 * <p>The model input/output names come from {@link OnnxModelPredictParams#INPUT_NAMES} and
 * {@link OnnxModelPredictParams#OUTPUT_NAMES}. When either is absent, the mapper's own input or
 * output column names ({@code inputCols} / {@code outputCols}, resolved by the superclass) are
 * used instead.
 */
public class OnnxModelPredictMapper extends DLPredictServiceMapper<OnnxPredictorClassLoaderFactory> {

    // NOTE(review): never read or written inside this class; kept so any subclass referring to it
    // still compiles. Check whether it shadows a field declared in DLPredictServiceMapper.
    protected DLPredictorService predictor;

    /** Names of the model inputs to feed; falls back to {@code inputCols} when not configured. */
    private String[] inputNames;

    /** Names of the model outputs to fetch; falls back to {@code outputCols} when not configured. */
    private String[] outputNames;

    /**
     * Creates a mapper backed by a newly constructed {@link OnnxPredictorClassLoaderFactory}.
     *
     * @param tableSchema schema of the data this mapper processes
     * @param params      mapper parameters
     */
    public OnnxModelPredictMapper(TableSchema tableSchema, Params params) {
        this(tableSchema, params, new OnnxPredictorClassLoaderFactory());
    }

    /**
     * Creates a mapper that loads the ONNX predictor through the given class-loader factory.
     *
     * @param tableSchema                     schema of the data this mapper processes
     * @param params                          mapper parameters; may carry explicit input/output names
     * @param onnxPredictorClassLoaderFactory factory handed to the superclass for loading the predictor
     */
    public OnnxModelPredictMapper(TableSchema tableSchema, Params params, OnnxPredictorClassLoaderFactory onnxPredictorClassLoaderFactory) {
        // NOTE(review): the meaning of the trailing `true` flag is defined by DLPredictServiceMapper — confirm there.
        super(tableSchema, params, onnxPredictorClassLoaderFactory, true);

        // inputCols/outputCols are populated by the superclass constructor, so they are usable here.
        String[] configuredInputNames = (String[]) params.get(OnnxModelPredictParams.INPUT_NAMES);
        this.inputNames = (configuredInputNames != null) ? configuredInputNames : this.inputCols;

        String[] configuredOutputNames = (String[]) params.get(OnnxModelPredictParams.OUTPUT_NAMES);
        this.outputNames = (configuredOutputNames != null) ? configuredOutputNames : this.outputCols;
    }

    /**
     * Builds the configuration used to instantiate the ONNX predictor.
     *
     * <p>The intra-op thread count is read from {@link HasIntraOpParallelism#INTRA_OP_PARALLELISM}
     * when that parameter is present in the params; otherwise that parameter's declared default is used.
     *
     * @return a new predictor configuration populated from this mapper's state
     */
    @Override
    protected DLPredictServiceMapper.PredictorConfig getPredictorConfig() {
        final Integer intraOpThreads;
        if (params.contains(HasIntraOpParallelism.INTRA_OP_PARALLELISM)) {
            intraOpThreads = (Integer) params.get(HasIntraOpParallelism.INTRA_OP_PARALLELISM);
        } else {
            intraOpThreads = HasIntraOpParallelism.INTRA_OP_PARALLELISM.getDefaultValue();
        }

        DLPredictServiceMapper.PredictorConfig config = new DLPredictServiceMapper.PredictorConfig();
        config.factory = factory;
        config.modelPath = localModelPath;
        config.inputNames = inputNames;
        config.outputNames = outputNames;
        config.outputTypeClasses = outputColTypeClasses;
        config.intraOpNumThreads = intraOpThreads;
        // NOTE(review): threadMode semantics are defined by the predictor service; this mapper always sets it to false.
        config.threadMode = false;
        return config;
    }
}
