package com.alibaba.alink.operator.common.tensorflow;

import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams;
import java.io.Serializable;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/tensorflow/TFTableModelPredictModelMapper.class */
/**
 * A {@link ModelMapper} whose lifecycle and per-row work are forwarded to an internal
 * {@link BaseTFSavedModelPredictRowMapper}.
 *
 * <p>Loading a model means asking {@link TFSavedModelUtils} to materialize it (from rows or from a
 * zip file) and handing the value it returns to the delegate via {@code setModelPath}.
 */
public class TFTableModelPredictModelMapper extends ModelMapper implements Serializable {
    /** Delegate that receives every open/close/map call and the loaded model path. */
    private final BaseTFSavedModelPredictRowMapper mapper;

    /**
     * Builds the mapper with a freshly created {@link TFPredictorClassLoaderFactory}.
     *
     * @param tableSchema first schema, forwarded to the superclass constructor
     * @param tableSchema2 second schema, forwarded to the superclass and to the delegate
     * @param params parameters forwarded to the superclass and to the delegate
     */
    public TFTableModelPredictModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        this(tableSchema, tableSchema2, params, new TFPredictorClassLoaderFactory());
    }

    /**
     * Builds the mapper with a caller-supplied class loader factory.
     *
     * @param tableSchema first schema, forwarded to the superclass constructor
     * @param tableSchema2 second schema, forwarded to the superclass and to the delegate
     * @param params parameters forwarded to the superclass and to the delegate
     * @param tFPredictorClassLoaderFactory factory handed to the delegate row mapper
     */
    public TFTableModelPredictModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params, TFPredictorClassLoaderFactory tFPredictorClassLoaderFactory) {
        super(tableSchema, tableSchema2, params);
        mapper = new BaseTFSavedModelPredictRowMapper(tableSchema2, params, tFPredictorClassLoaderFactory);
    }

    /** Opens the delegate row mapper. */
    @Override
    public void open() {
        mapper.open();
    }

    /** Closes the delegate row mapper. */
    @Override
    public void close() {
        mapper.close();
    }

    /**
     * Loads the saved model from the given rows and passes the result to the delegate.
     *
     * @param list rows handed to {@code TFSavedModelUtils.loadSavedModelFromRows}
     */
    @Override
    public void loadModel(List<Row> list) {
        mapper.setModelPath(TFSavedModelUtils.loadSavedModelFromRows(list));
    }

    /**
     * Same as {@link #loadModel(List)}, but accepts any {@link Iterable} of rows.
     *
     * @param iterable rows handed to {@code TFSavedModelUtils.loadSavedModelFromRows}
     */
    public void loadModel(Iterable<Row> iterable) {
        mapper.setModelPath(TFSavedModelUtils.loadSavedModelFromRows(iterable));
    }

    /**
     * Loads the saved model from a zip file and passes the result to the delegate.
     *
     * @param str argument handed to {@code TFSavedModelUtils.loadSavedModelFromZipFile}
     */
    public void loadModelFromZipFile(String str) {
        mapper.setModelPath(TFSavedModelUtils.loadSavedModelFromZipFile(str));
    }

    /**
     * Defers entirely to the superclass implementation.
     *
     * @param list rows forwarded to {@code super.createNew}
     * @return whatever the superclass returns
     */
    @Override
    public ModelMapper createNew(List<Row> list) {
        return super.createNew(list);
    }

    /**
     * Assembles the I/O column description for this mapper.
     *
     * @param tableSchema not read by this implementation
     * @param tableSchema2 schema whose field names are the fallback selected columns
     * @param params source of the selected columns, output schema string and reserved columns
     * @return tuple of (selected columns, output column names, output column types, reserved columns)
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        String[] configuredCols = (String[]) params.get(TFTableModelPredictParams.SELECTED_COLS);
        // When no columns were selected explicitly, fall back to every field of the second schema.
        String[] selectedCols = configuredCols != null ? configuredCols : tableSchema2.getFieldNames();

        String outputSchemaStr = (String) params.get(TFTableModelPredictParams.OUTPUT_SCHEMA_STR);
        TableSchema outputSchema = TableUtil.schemaStr2Schema(outputSchemaStr);
        String[] outputNames = outputSchema.getFieldNames();
        TypeInformation<?>[] outputTypes = outputSchema.getFieldTypes();

        String[] reservedCols = (String[]) params.get(TFTableModelPredictParams.RESERVED_COLS);
        return Tuple4.of(selectedCols, outputNames, outputTypes, reservedCols);
    }

    /**
     * Forwards a single mapping call to the delegate row mapper.
     *
     * <p>Decompiler note: the access modifier was originally {@code protected}; it is kept
     * {@code public} here to match the decompiled signature.
     *
     * @param slicedSelectedSample input slice forwarded to the delegate
     * @param slicedResult result slice forwarded to the delegate
     * @throws Exception if the delegate's {@code map} throws
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        mapper.map(slicedSelectedSample, slicedResult);
    }
}
