package com.alibaba.alink.operator.common.tensorflow;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.dl.plugin.DLPredictorService;
import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkPluginErrorException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.mapper.FlatMapper;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.dl.HasInferBatchSizeDefaultAs256;
import com.alibaba.alink.params.dl.HasModelPath;
import com.alibaba.alink.params.tensorflow.savedmodel.BaseTFSavedModelPredictParams;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tensorflow/BaseTFSavedModelPredictRowFlatMapper.class */
/**
 * Flat mapper that feeds rows to a TensorFlow SavedModel predictor in batches.
 *
 * <p>{@link #flatMap(Row, Collector)} only enqueues the incoming row together with its collector.
 * A single background {@link InferenceRunner} thread drains the queue, groups rows into batches of
 * at most {@code batchSize}, calls the predictor and emits the merged rows through the collector
 * that arrived with each row.
 *
 * <p>Lifecycle: {@link #open()} loads the predictor and starts the runner; {@link #close()} asks
 * the runner to finish the queued rows, waits for it and releases the executor and the predictor.
 */
public class BaseTFSavedModelPredictRowFlatMapper extends FlatMapper implements Serializable {
    private static final Logger LOG = LoggerFactory.getLogger(BaseTFSavedModelPredictRowFlatMapper.class);
    /** How long one {@code offer} waits for queue space before the runner's health is re-checked. */
    private static final long QUEUE_OFFER_TIMEOUT_MS = 50;
    private final TFPredictorClassLoaderFactory factory;
    private final String graphDefTag;
    private final String signatureDefKey;
    private final int[] tfInputColIds;
    private final String[] tfOutputCols;
    private final Class<?>[] tfOutputColTypeClasses;
    private final OutputColsHelper outputColsHelper;
    private String[] inputSignatureDefs;
    private String[] outputSignatureDefs;
    private String modelPath;
    private int batchSize;
    private AtomicBoolean stopInference;
    private ArrayBlockingQueue<Pair<Row, Collector<Row>>> queue;
    private ExecutorService executorService;
    private Future<?> inferenceRunnerFuture;
    private DLPredictorService predictor;

    /**
     * Background task that drains the queue and runs batched inference until it is asked to stop
     * and the queue is empty.
     */
    private class InferenceRunner implements Runnable {
        private InferenceRunner() {
        }

        /**
         * Polls the queue, accumulating rows until a full batch is available, and processes each
         * batch. When stopping, a final partial batch is flushed.
         *
         * @throws AkUnclassifiedErrorException if the thread is interrupted while polling; the
         *     interrupt status is restored first so the failure is visible through the future
         */
        @Override // java.lang.Runnable
        public void run() {
            List<Pair<Row, Collector<Row>>> batch = new ArrayList<>();
            while (!(stopInference.get() && queue.isEmpty())) {
                Pair<Row, Collector<Row>> item;
                try {
                    // Bounded wait so the stop flag is re-evaluated at least once per second.
                    item = queue.poll(1L, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    // Do not swallow the interrupt: restore the flag and fail the task so that
                    // flatMap()/close() observe the failure via the future.
                    Thread.currentThread().interrupt();
                    throw new AkUnclassifiedErrorException("Inference runner is interrupted.", e);
                }
                if (item != null) {
                    batch.add(item);
                }
                if (batch.size() == batchSize) {
                    processRows(batch);
                    batch.clear();
                }
            }
            if (!batch.isEmpty()) {
                processRows(batch);
                batch.clear();
            }
        }

        /**
         * Runs the predictor on one batch and emits one result row per input row.
         *
         * @param list pairs of input row and the collector that should receive its result
         */
        public void processRows(List<Pair<Row, Collector<Row>>> list) {
            long startMs = System.currentTimeMillis();
            int rowCount = list.size();

            // Column-major layout: one value list per selected input column.
            List<List<?>> inputColumns = new ArrayList<>();
            for (int colId : tfInputColIds) {
                List<Object> values = new ArrayList<>(rowCount);
                for (Pair<Row, Collector<Row>> item : list) {
                    values.add(item.getLeft().getField(colId));
                }
                inputColumns.add(values);
            }

            List<List<?>> outputColumns = predictor.predictRows(inputColumns, rowCount);

            // Transpose the column-major predictor output back into one Row per input row.
            Row[] predictions = new Row[rowCount];
            for (int r = 0; r < rowCount; r++) {
                predictions[r] = new Row(tfOutputCols.length);
            }
            for (int c = 0; c < outputColumns.size(); c++) {
                List<?> column = outputColumns.get(c);
                for (int r = 0; r < rowCount; r++) {
                    predictions[r].setField(c, column.get(r));
                }
            }

            for (int r = 0; r < rowCount; r++) {
                Pair<Row, Collector<Row>> item = list.get(r);
                item.getRight().collect(outputColsHelper.getResultRow(item.getLeft(), predictions[r]));
            }

            long elapsedMs = System.currentTimeMillis() - startMs;
            LOG.info("{} items cost {} ms.", rowCount, elapsedMs);
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.printf("%s items cost %s ms.%n", rowCount, elapsedMs);
            }
        }
    }

    /**
     * Creates a mapper that uses a default {@link TFPredictorClassLoaderFactory}.
     *
     * @param tableSchema schema of the input rows
     * @param params operator parameters
     */
    public BaseTFSavedModelPredictRowFlatMapper(TableSchema tableSchema, Params params) {
        this(tableSchema, params, new TFPredictorClassLoaderFactory());
    }

    /**
     * Creates a mapper with an explicit class loader factory for the predictor plugin.
     *
     * <p>Selected columns default to all input columns; input/output signature definitions default
     * to the selected column names and the output schema's column names respectively. The batch
     * size defaults to 256 unless {@code INFER_BATCH_SIZE} is set.
     *
     * @param tableSchema schema of the input rows
     * @param params operator parameters; {@code OUTPUT_SCHEMA_STR} is mandatory
     * @param tFPredictorClassLoaderFactory factory for the class loader that loads the predictor
     * @throws AkIllegalOperatorParameterException if the output schema string is not set
     */
    public BaseTFSavedModelPredictRowFlatMapper(TableSchema tableSchema, Params params, TFPredictorClassLoaderFactory tFPredictorClassLoaderFactory) {
        super(tableSchema, params);
        this.batchSize = 256;
        this.factory = tFPredictorClassLoaderFactory;
        this.graphDefTag = params.get(BaseTFSavedModelPredictParams.GRAPH_DEF_TAG);
        this.signatureDefKey = params.get(BaseTFSavedModelPredictParams.SIGNATURE_DEF_KEY);

        String[] selectedCols = params.get(BaseTFSavedModelPredictParams.SELECTED_COLS);
        if (selectedCols == null) {
            selectedCols = tableSchema.getFieldNames();
        }
        this.tfInputColIds = TableUtil.findColIndicesWithAssertAndHint(tableSchema, selectedCols);

        String[] inputDefs = params.get(BaseTFSavedModelPredictParams.INPUT_SIGNATURE_DEFS);
        this.inputSignatureDefs = inputDefs != null ? inputDefs : selectedCols;

        AkPreconditions.checkArgument(params.contains(BaseTFSavedModelPredictParams.OUTPUT_SCHEMA_STR), (ExceptionWithErrorCode) new AkIllegalOperatorParameterException("Must set outputSchemaStr."));
        TableSchema outputSchema = TableUtil.schemaStr2Schema(params.get(BaseTFSavedModelPredictParams.OUTPUT_SCHEMA_STR));
        this.tfOutputCols = outputSchema.getFieldNames();

        String[] outputDefs = params.get(BaseTFSavedModelPredictParams.OUTPUT_SIGNATURE_DEFS);
        this.outputSignatureDefs = outputDefs != null ? outputDefs : this.tfOutputCols;

        TypeInformation<?>[] outputTypes = outputSchema.getFieldTypes();
        Class<?>[] outputTypeClasses = new Class<?>[outputTypes.length];
        for (int i = 0; i < outputTypes.length; i++) {
            outputTypeClasses[i] = outputTypes[i].getTypeClass();
        }
        this.tfOutputColTypeClasses = outputTypeClasses;

        this.outputColsHelper = new OutputColsHelper(tableSchema, this.tfOutputCols, outputTypes, params.get(BaseTFSavedModelPredictParams.RESERVED_COLS));
        if (params.contains(HasModelPath.MODEL_PATH)) {
            this.modelPath = params.get(HasModelPath.MODEL_PATH);
        }
        if (params.contains(HasInferBatchSizeDefaultAs256.INFER_BATCH_SIZE)) {
            this.batchSize = params.get(HasInferBatchSizeDefaultAs256.INFER_BATCH_SIZE);
        }
    }

    /**
     * Sets the path of the SavedModel to load in {@link #open()}.
     *
     * @param str model path
     * @return this mapper, for chaining
     */
    public BaseTFSavedModelPredictRowFlatMapper setModelPath(String str) {
        this.modelPath = str;
        return this;
    }

    /** Returns the schema of the rows produced by this mapper. */
    @Override // com.alibaba.alink.common.mapper.FlatMapper
    public TableSchema getOutputSchema() {
        return this.outputColsHelper.getResultSchema();
    }

    /**
     * Loads and opens the predictor, then starts the background inference runner.
     *
     * @throws AkPluginErrorException if the predictor implementation cannot be instantiated
     */
    @Override // com.alibaba.alink.common.mapper.FlatMapper
    public void open() {
        AkPreconditions.checkArgument(this.modelPath != null, "Model path is not set.");
        Integer intraOpParallelism = this.params.contains(BaseTFSavedModelPredictParams.INTRA_OP_PARALLELISM)
            ? this.params.get(BaseTFSavedModelPredictParams.INTRA_OP_PARALLELISM)
            : null;

        try {
            // The implementation lives in a plugin, so it is loaded reflectively by name.
            this.predictor = (DLPredictorService) this.factory.create()
                .loadClass("com.alibaba.alink.common.dl.plugin.TFPredictorServiceImpl")
                .getConstructor()
                .newInstance();
        } catch (ClassNotFoundException | IllegalAccessException | InstantiationException | NoSuchMethodException | InvocationTargetException e) {
            throw new AkPluginErrorException("Failed to create TFPredictorServiceImpl instance.", e);
        }

        HashMap<String, Object> config = new HashMap<>();
        config.put("model_path", this.modelPath);
        config.put(TFSavedModelConstants.GRAPH_DEF_TAG_KEY, this.graphDefTag);
        config.put(TFSavedModelConstants.SIGNATURE_DEF_KEY_KEY, this.signatureDefKey);
        config.put(TFSavedModelConstants.INPUT_SIGNATURE_DEFS_KEY, this.inputSignatureDefs);
        config.put(TFSavedModelConstants.OUTPUT_SIGNATURE_DEFS_KEY, this.outputSignatureDefs);
        config.put("output_type_classes", this.tfOutputColTypeClasses);
        config.put(TFSavedModelConstants.INTRA_OP_PARALLELISM_KEY, intraOpParallelism);
        config.put(TFSavedModelConstants.INTER_OP_PARALLELISM_KEY, null);
        this.predictor.open(config);

        this.queue = new ArrayBlockingQueue<>(this.batchSize);
        this.stopInference = new AtomicBoolean(false);
        this.executorService = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingDeque<>());
        this.inferenceRunnerFuture = this.executorService.submit(new InferenceRunner());
    }

    /**
     * Stops the inference runner after it has drained the queue and releases the executor and the
     * predictor.
     *
     * <p>The executor is always shut down, even when the runner failed, so its worker thread does
     * not outlive the mapper. The predictor is closed whenever the runner has terminated. Safe to
     * call when {@link #open()} did not complete.
     *
     * @throws AkUnclassifiedErrorException if the runner failed, the wait was interrupted, or the
     *     predictor could not be closed; a secondary failure is attached as suppressed
     */
    @Override // com.alibaba.alink.common.mapper.FlatMapper
    public void close() {
        if (this.stopInference != null) {
            this.stopInference.set(true);
        }

        AkUnclassifiedErrorException failure = null;
        boolean runnerTerminated = true;
        if (this.inferenceRunnerFuture != null) {
            try {
                this.inferenceRunnerFuture.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                runnerTerminated = false;
                failure = new AkUnclassifiedErrorException("Inference runner is failed or interrupted.", e);
            } catch (ExecutionException e) {
                failure = new AkUnclassifiedErrorException("Inference runner is failed or interrupted.", e);
            }
        }

        if (this.executorService != null) {
            if (runnerTerminated) {
                this.executorService.shutdown();
            } else {
                // We gave up waiting; interrupt the worker so it does not linger.
                this.executorService.shutdownNow();
            }
        }

        // Only close the predictor once the runner can no longer be using it.
        if (runnerTerminated && this.predictor != null) {
            try {
                this.predictor.close();
            } catch (Exception e) {
                AkUnclassifiedErrorException closeFailure = new AkUnclassifiedErrorException("Failed to close predictor.", e);
                if (failure == null) {
                    failure = closeFailure;
                } else {
                    failure.addSuppressed(closeFailure);
                }
            }
        }

        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Enqueues a row for batched inference; results are emitted later by the runner thread through
     * {@code collector}.
     *
     * <p>Blocks while the queue is full, re-checking every {@code QUEUE_OFFER_TIMEOUT_MS}
     * milliseconds whether the runner is still alive so a dead runner cannot cause a hang.
     *
     * @param row input row
     * @param collector collector that receives the result row for {@code row}
     * @throws AkUnclassifiedErrorException if the runner has failed or the health check was interrupted
     * @throws IllegalStateException if the runner has already finished and the queue is full
     * @throws Exception if interrupted while waiting for queue space
     */
    @Override // com.alibaba.alink.common.mapper.FlatMapper
    public void flatMap(Row row, Collector<Row> collector) throws Exception {
        Pair<Row, Collector<Row>> item = Pair.of(row, collector);
        while (!this.queue.offer(item, QUEUE_OFFER_TIMEOUT_MS, TimeUnit.MILLISECONDS)) {
            try {
                // Zero timeout: only probes whether the runner has terminated.
                this.inferenceRunnerFuture.get(0L, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new AkUnclassifiedErrorException("Inference runner is failed or interrupted.", e);
            } catch (ExecutionException e) {
                throw new AkUnclassifiedErrorException("Inference runner is failed or interrupted.", e);
            } catch (TimeoutException ignored) {
                // Runner is still working; keep waiting for queue space.
                continue;
            }
            // get() returned normally: the runner has exited and will never drain the full queue.
            throw new IllegalStateException("Inference runner has already stopped; cannot accept more rows.");
        }
    }
}
