package com.alibaba.alink.operator.common.pytorch;

import com.alibaba.alink.common.dl.plugin.DLPredictServiceMapper;
import com.alibaba.alink.common.dl.plugin.TorchPredictorClassLoaderFactory;
import com.alibaba.alink.common.io.plugin.ResourcePluginFactory;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.dl.TorchModelPredictParams;
import java.io.File;
import java.util.Arrays;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/**
 * Mapper that runs prediction with a PyTorch (TorchScript) model through the
 * {@link DLPredictServiceMapper} plugin machinery.
 *
 * <p>Input columns, output schema and reserved columns are taken from {@link TorchModelPredictParams}.
 * The predictor is configured with the native libtorch {@code lib} directory. That directory is resolved
 * via {@code LibtorchUtils} for libtorch {@value #LIBTORCH_VERSION}.
 */
public class TorchModelPredictMapper extends DLPredictServiceMapper<TorchPredictorClassLoaderFactory> {
    /** Version passed to the default {@link TorchPredictorClassLoaderFactory}. */
    private static final String TORCH_JAVA_VERSION = "1.8.0r1";
    /** libtorch version whose installation path is looked up in {@link #open()}. */
    private static final String LIBTORCH_VERSION = "1.8.1";

    /** Value of {@code INTRA_OP_PARALLELISM}, forwarded as the predictor's intra-op thread count. */
    private final int intraOpParallelism;
    /** Java classes of the output schema's field types; filled in by {@link #open()}. */
    private Class<?>[] outputColTypeClasses;
    /** Absolute path of the libtorch {@code lib} directory; filled in by {@link #open()}. */
    private String libraryPath;
    /** Factory handed to {@code LibtorchUtils.getLibtorchPath} when locating libtorch. */
    private final ResourcePluginFactory resourceFactory;

    /**
     * Creates a mapper backed by a class-loader factory for torch-java {@value #TORCH_JAVA_VERSION}.
     *
     * @param tableSchema schema of the input table
     * @param params      operator parameters
     */
    public TorchModelPredictMapper(TableSchema tableSchema, Params params) {
        this(tableSchema, params, new TorchPredictorClassLoaderFactory(TORCH_JAVA_VERSION));
    }

    /**
     * Creates a mapper that uses the supplied class-loader factory.
     *
     * @param tableSchema                      schema of the input table
     * @param params                           operator parameters
     * @param torchPredictorClassLoaderFactory factory passed through to the superclass
     */
    public TorchModelPredictMapper(TableSchema tableSchema, Params params,
                                   TorchPredictorClassLoaderFactory torchPredictorClassLoaderFactory) {
        super(tableSchema, params, torchPredictorClassLoaderFactory, false);
        // Auto-unboxing: the parameter value is read as an Integer.
        Integer parallelism = params.get(TorchModelPredictParams.INTRA_OP_PARALLELISM);
        this.intraOpParallelism = parallelism;
        this.resourceFactory = new ResourcePluginFactory();
    }

    /**
     * Derives the I/O layout from the parameters.
     *
     * @return (selected input columns, output column names, output column types, reserved columns)
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema,
                                                                                        Params params) {
        String[] inputCols = params.get(TorchModelPredictParams.SELECTED_COLS);
        String outputSchemaStr = params.get(TorchModelPredictParams.OUTPUT_SCHEMA_STR);
        TableSchema outputSchema = TableUtil.schemaStr2Schema(outputSchemaStr);
        String[] outputNames = outputSchema.getFieldNames();
        TypeInformation<?>[] outputTypes = outputSchema.getFieldTypes();
        String[] reservedCols = params.get(TorchModelPredictParams.RESERVED_COLS);
        return Tuple4.of(inputCols, outputNames, outputTypes, reservedCols);
    }

    /**
     * Builds the predictor configuration.
     *
     * <p>The configuration includes the factory, the local model path, the libtorch library path,
     * the output type classes and the intra-op thread count. Thread mode is always disabled.
     * {@code libraryPath} and {@code outputTypeClasses} are only populated after {@link #open()}
     * has assigned them.
     */
    @Override
    public DLPredictServiceMapper.PredictorConfig getPredictorConfig() {
        DLPredictServiceMapper.PredictorConfig config = new DLPredictServiceMapper.PredictorConfig();
        config.factory = factory;
        config.modelPath = localModelPath;
        config.libraryPath = libraryPath;
        config.outputTypeClasses = outputColTypeClasses;
        config.intraOpNumThreads = intraOpParallelism;
        config.threadMode = false;
        return config;
    }

    /**
     * Resolves the output type classes and the libtorch library directory, then delegates to
     * {@code super.open()}.
     *
     * <p>The two fields are assigned before the delegation. NOTE(review): this ordering presumably
     * matters because the superclass reads {@link #getPredictorConfig()} during open — confirm.
     */
    @Override
    public void open() {
        String outputSchemaStr = (String) params.get(TorchModelPredictParams.OUTPUT_SCHEMA_STR);
        TypeInformation<?>[] outputTypes = TableUtil.schemaStr2Schema(outputSchemaStr).getFieldTypes();
        Class<?>[] typeClasses = new Class<?>[outputTypes.length];
        for (int idx = 0; idx < outputTypes.length; idx++) {
            typeClasses[idx] = outputTypes[idx].getTypeClass();
        }
        this.outputColTypeClasses = typeClasses;

        File libDir = new File(LibtorchUtils.getLibtorchPath(resourceFactory, LIBTORCH_VERSION), "lib");
        this.libraryPath = libDir.getAbsolutePath();

        super.open();
    }
}
