package com.alibaba.alink.operator.common.regression.tensorflow;

import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.linalg.tensor.FloatTensor;
import com.alibaba.alink.common.mapper.FlatModelMapper;
import com.alibaba.alink.common.mapper.IterableModelLoader;
import com.alibaba.alink.common.model.LabeledModelDataConverter;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.tensorflow.CachedRichModelMapper;
import com.alibaba.alink.operator.common.tensorflow.TFModelDataConverterUtils;
import com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictFlatModelMapper;
import com.alibaba.alink.params.dl.HasInferBatchSizeDefaultAs256;
import com.alibaba.alink.params.regression.TFTableModelRegressionPredictParams;
import com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams;
import com.alibaba.alink.pipeline.LocalPredictor;
import com.alibaba.alink.pipeline.LocalPredictorLoader;
import java.util.List;
import org.apache.commons.collections.CollectionUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionFlatModelMapper.class */
/**
 * Flat model mapper for regression prediction with a TensorFlow table model.
 *
 * <p>Each input row is optionally run through a preprocessing {@link LocalPredictor}, then handed to a
 * wrapped {@link TFTableModelPredictFlatModelMapper}. The wrapped mapper emits a single-element
 * {@link FloatTensor} in the prediction column, which {@link #extractPredictResult(Row)} converts to a
 * {@code Double}.
 */
public class TFTableModelRegressionFlatModelMapper extends CachedRichModelMapper implements IterableModelLoader {
    /** Preprocessing predictor restored from the model; stays {@code null} when the model has none. */
    private LocalPredictor preprocessLocalPredictor;
    /** Delegate performing the TensorFlow inference; created in {@link #loadFromModelData}. */
    private TFTableModelPredictFlatModelMapper tfFlatModelMapper;
    private final TypeInformation<?> labelType;
    private final String predCol;
    /** Index of {@link #predCol} in the delegate's output schema; set in {@link #loadFromModelData}. */
    private int predColId;
    private final TFPredictorClassLoaderFactory factory;

    /**
     * @param tableSchema  model schema; also used to extract the label type
     * @param tableSchema2 data schema
     * @param params       prediction parameters; the prediction column name is read from them
     */
    public TFTableModelRegressionFlatModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.predCol = (String) params.get(TFTableModelRegressionPredictParams.PREDICTION_COL);
        this.labelType = LabeledModelDataConverter.extractLabelType(tableSchema);
        this.factory = new TFPredictorClassLoaderFactory();
    }

    /**
     * Opens the wrapped TensorFlow mapper. A model must have been loaded beforehand, because the
     * delegate is only created while loading the model.
     */
    @Override // com.alibaba.alink.common.mapper.FlatModelMapper, com.alibaba.alink.common.mapper.FlatMapper
    public void open() {
        this.tfFlatModelMapper.open();
    }

    /** Closes the wrapped TensorFlow mapper; does nothing when no model was ever loaded. */
    @Override // com.alibaba.alink.common.mapper.FlatModelMapper, com.alibaba.alink.common.mapper.FlatMapper
    public void close() {
        if (null != this.tfFlatModelMapper) {
            this.tfFlatModelMapper.close();
        }
    }

    /**
     * Loads the regression model from an in-memory list of model rows.
     *
     * @param list serialized model rows
     */
    @Override // com.alibaba.alink.common.mapper.FlatModelMapper
    public void loadModel(List<Row> list) {
        TFTableModelRegressionModelDataConverter converter = new TFTableModelRegressionModelDataConverter(this.labelType);
        loadFromModelData(converter.load(list), converter.getModelSchema());
    }

    /**
     * Loads the regression model from an iterable of model rows.
     *
     * @param iterable serialized model rows
     */
    @Override // com.alibaba.alink.common.mapper.IterableModelLoader
    public void loadIterableModel(Iterable<Row> iterable) {
        TFTableModelRegressionModelDataConverter converter = new TFTableModelRegressionModelDataConverter(this.labelType);
        loadFromModelData(converter.loadIterable(iterable), converter.getModelSchema());
    }

    /**
     * Builds the optional preprocessing predictor and the wrapped TensorFlow mapper from model data.
     *
     * @param tFTableModelRegressionModelData deserialized regression model data
     * @param tableSchema                     schema of the model table, passed on to the wrapped mapper
     * @throws AkUnclassifiedErrorException if the preprocessing pipeline model cannot be loaded
     */
    protected void loadFromModelData(TFTableModelRegressionModelData tFTableModelRegressionModelData, TableSchema tableSchema) {
        Params meta = tFTableModelRegressionModelData.getMeta();
        String outputSignatureDef = (String) meta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
        TypeInformation<FloatTensor> outputType = AlinkTypes.FLOAT_TENSOR;

        // The TF mapper consumes the preprocessed schema when a preprocessing pipeline exists,
        // otherwise the raw data schema.
        TableSchema tfInputSchema = getDataSchema();
        if (CollectionUtils.isNotEmpty(tFTableModelRegressionModelData.getPreprocessPipelineModelRows())) {
            try {
                this.preprocessLocalPredictor = LocalPredictorLoader.load(
                    tFTableModelRegressionModelData.getPreprocessPipelineModelRows(),
                    TableUtil.schemaStr2Schema(tFTableModelRegressionModelData.getPreprocessPipelineModelSchemaStr()),
                    tfInputSchema);
                tfInputSchema = this.preprocessLocalPredictor.getOutputSchema();
            } catch (Exception e) {
                throw new AkUnclassifiedErrorException("Cannot initialize preprocess PipelineModel", e);
            }
        }

        String[] tfInputCols = (String[]) meta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
        String outputSchemaStr = TableUtil.schema2SchemaStr(
            TableSchema.builder().field(this.predCol, outputType).build());

        Params tfMapperParams = new Params();
        tfMapperParams.set(TFTableModelPredictParams.RESERVED_COLS, new String[0]);
        tfMapperParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] {outputSignatureDef});
        tfMapperParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR, outputSchemaStr);
        tfMapperParams.set(TFTableModelPredictParams.SELECTED_COLS, tfInputCols);
        if (meta.contains(HasInferBatchSizeDefaultAs256.INFER_BATCH_SIZE)) {
            tfMapperParams.set(HasInferBatchSizeDefaultAs256.INFER_BATCH_SIZE,
                meta.get(HasInferBatchSizeDefaultAs256.INFER_BATCH_SIZE));
        }
        // NOTE(review): this unconditional set always overrides the value copied from the model meta
        // just above, so the batch size stored with the model never takes effect. Kept as-is to
        // preserve behavior; confirm whether the meta value was meant to be a fallback.
        tfMapperParams.set(HasInferBatchSizeDefaultAs256.INFER_BATCH_SIZE,
            this.params.get(HasInferBatchSizeDefaultAs256.INFER_BATCH_SIZE));

        this.tfFlatModelMapper = new TFTableModelPredictFlatModelMapper(tableSchema, tfInputSchema, tfMapperParams, this.factory);
        if (null != tFTableModelRegressionModelData.getTfModelZipPath()) {
            this.tfFlatModelMapper.loadModelFromZipFile(tFTableModelRegressionModelData.getTfModelZipPath());
        } else {
            this.tfFlatModelMapper.loadModel(tFTableModelRegressionModelData.getTfModelRows());
        }
        this.predColId = TableUtil.findColIndex(this.tfFlatModelMapper.getOutputSchema(), this.predCol);
    }

    /**
     * Reloads the wrapped TensorFlow mapper with new model rows and returns this same instance.
     *
     * <p>NOTE(review): the rows are passed directly to the wrapped mapper rather than through
     * {@link #loadModel(List)}, and no new mapper object is created — confirm this is intended.
     *
     * @param list new model rows
     * @return this mapper
     */
    @Override // com.alibaba.alink.common.mapper.FlatModelMapper
    public FlatModelMapper createNew(List<Row> list) {
        this.tfFlatModelMapper.loadModel(list);
        return this;
    }

    /**
     * Preprocesses the row when a preprocessing predictor exists, then delegates to the wrapped
     * TensorFlow mapper. Results are emitted through a {@code PredictionCollector} created from a copy
     * of the original, un-preprocessed row.
     *
     * @param row       input row
     * @param collector downstream collector
     */
    @Override // com.alibaba.alink.operator.common.tensorflow.CachedRichModelMapper, com.alibaba.alink.common.mapper.FlatMapper
    public void flatMap(Row row, Collector<Row> collector) throws Exception {
        CachedRichModelMapper.PredictionCollector predictionCollector =
            new CachedRichModelMapper.PredictionCollector(Row.copy(row), collector);
        Row tfInput = row;
        if (null != this.preprocessLocalPredictor) {
            tfInput = this.preprocessLocalPredictor.map(row);
        }
        this.tfFlatModelMapper.flatMap(tfInput, predictionCollector);
    }

    /**
     * Converts the single-element prediction tensor of a TF output row into a {@code Double}.
     *
     * @param row output row of the wrapped TensorFlow mapper
     * @return the only element of the prediction tensor, widened to {@code Double}
     */
    @Override // com.alibaba.alink.operator.common.tensorflow.CachedRichModelMapper
    protected Object extractPredictResult(Row row) throws Exception {
        FloatTensor prediction = (FloatTensor) row.getField(this.predColId);
        AkPreconditions.checkState(prediction.size() == 1, "The prediction tensor must have size 1");
        // A one-element tensor can have any rank (shape [], [1], [1, 1], ...). Its only element sits
        // at the all-zero coordinate, so build a zero-filled index with one entry per dimension.
        long[] origin = new long[prediction.shape().length];
        return Double.valueOf(prediction.getFloat(origin));
    }

    /**
     * Prediction details are not available for this mapper.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // com.alibaba.alink.operator.common.tensorflow.CachedRichModelMapper
    protected Tuple2<Object, String> extractPredictResultDetail(Row row) throws Exception {
        throw new UnsupportedOperationException("Not supported predict with details in TFTableModelRegressionFlatModelMapper");
    }
}
