package com.alibaba.alink.operator.common.tensorflow;

import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.params.tensorflow.savedmodel.HasOutputBatchAxes;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/**
 * TensorFlow SavedModel predict mapper that predicts one row at a time.
 *
 * <p>Compared to {@link BaseTFSavedModelPredictMapper}, this class does two things:
 * <ul>
 *   <li>It forwards the {@link HasOutputBatchAxes#OUTPUT_BATCH_AXES} setting to the predictor
 *       configuration.</li>
 *   <li>It implements {@link #map} by sending all selected input values of a row to the
 *       predictor's {@code predictRow}.</li>
 * </ul>
 */
public class BaseTFSavedModelPredictRowMapper extends BaseTFSavedModelPredictMapper {

    /**
     * Value of {@link HasOutputBatchAxes#OUTPUT_BATCH_AXES}.
     *
     * <p>If the parameter is not set, this is a zero-filled array with one entry per element of
     * {@code tfOutputCols}.
     */
    private final int[] outputBatchAxes;

    /**
     * Creates the mapper using a default {@link TFPredictorClassLoaderFactory}.
     *
     * @param tableSchema schema of the input table
     * @param params      operator parameters
     */
    public BaseTFSavedModelPredictRowMapper(TableSchema tableSchema, Params params) {
        this(tableSchema, params, new TFPredictorClassLoaderFactory());
    }

    /**
     * Creates the mapper with an explicit predictor class-loader factory.
     *
     * @param tableSchema                   schema of the input table
     * @param params                        operator parameters
     * @param tFPredictorClassLoaderFactory factory passed through to the superclass
     */
    public BaseTFSavedModelPredictRowMapper(TableSchema tableSchema, Params params,
                                            TFPredictorClassLoaderFactory tFPredictorClassLoaderFactory) {
        super(tableSchema, params, tFPredictorClassLoaderFactory);
        // Without an explicit setting, every output gets axis value 0
        // (Java zero-initializes int arrays).
        this.outputBatchAxes = params.contains(HasOutputBatchAxes.OUTPUT_BATCH_AXES)
            ? (int[]) params.get(HasOutputBatchAxes.OUTPUT_BATCH_AXES)
            : new int[this.tfOutputCols.length];
    }

    /**
     * Returns the superclass predictor configuration, extended with the output batch axes.
     *
     * <p>The axes are stored under {@link TFSavedModelConstants#OUTPUT_BATCH_AXES}.
     *
     * <p>NOTE(review): decompiler metadata indicates the original source declared this method
     * {@code protected}. It is kept {@code public} here so that it stays compatible with the
     * superclass declaration as compiled.
     *
     * @return the predictor configuration map
     */
    @Override
    public Map<String, Object> getPredictorConfig() {
        Map<String, Object> predictorConfig = super.getPredictorConfig();
        predictorConfig.put(TFSavedModelConstants.OUTPUT_BATCH_AXES, this.outputBatchAxes);
        return predictorConfig;
    }

    /**
     * Predicts a single row.
     *
     * <p>All selected input values, in column order, are passed to the predictor's
     * {@code predictRow}. The first {@code slicedResult.length()} returned values are written to
     * the result columns. Any additional returned values are ignored.
     *
     * @param slicedSelectedSample selected input values of the row
     * @param slicedResult         destination for the output values
     * @throws IllegalStateException if the predictor returns {@code null}, or returns fewer values
     *                               than there are result columns
     * @throws Exception             propagated from the predictor
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult)
        throws Exception {
        final int numInputs = slicedSelectedSample.length();
        final List<Object> inputs = new ArrayList<>(numInputs);
        for (int i = 0; i < numInputs; i++) {
            inputs.add(slicedSelectedSample.get(i));
        }

        final List<?> outputs = this.predictor.predictRow(inputs);
        final int numOutputs = slicedResult.length();
        // Fail with a descriptive message instead of a bare NPE / IndexOutOfBoundsException.
        if (outputs == null) {
            throw new IllegalStateException("TF predictor returned null for a row prediction.");
        }
        if (outputs.size() < numOutputs) {
            throw new IllegalStateException(String.format(
                "TF predictor returned %d value(s), but %d output column(s) are expected.",
                outputs.size(), numOutputs));
        }
        for (int i = 0; i < numOutputs; i++) {
            slicedResult.set(i, outputs.get(i));
        }
    }
}
