package com.alibaba.alink.operator.batch.utils;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamsIgnoredOnWebUI;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec;
import com.alibaba.alink.common.comqueue.IterTaskObjKeeper;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.io.filesystem.FilePath;
import com.alibaba.alink.common.mapper.IterableModelLoader;
import com.alibaba.alink.common.mapper.IterableModelLoaderModelMapperAdapter;
import com.alibaba.alink.common.mapper.IterableModelLoaderModelMapperAdapterMT;
import com.alibaba.alink.common.mapper.ModelBunchMapperAdapter;
import com.alibaba.alink.common.mapper.ModelBunchMapperAdapterMT;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.mapper.ModelMapperAdapter;
import com.alibaba.alink.common.mapper.ModelMapperAdapterMT;
import com.alibaba.alink.common.mapper.ModelStreamModelMapperAdapter;
import com.alibaba.alink.common.model.BroadcastVariableModelSource;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp;
import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils;
import com.alibaba.alink.params.mapper.ModelMapperParams;
import com.alibaba.alink.params.shared.HasModelFilePath;
import com.alibaba.alink.params.shared.HasPredictBatchSize;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.function.TriFunction;

/**
 * Batch operator that applies a {@link ModelMapper} to every row of a data input.
 *
 * <p>The model comes from the first linked input when two inputs are given. Otherwise it is read
 * from the {@code modelFilePath} parameter. The mapper is built by {@link #mapperBuilder} from
 * (model schema, data schema, params). The prediction strategy is picked by
 * {@link #calcResultRows(BatchOperator, BatchOperator, ModelMapper, Params)}.
 *
 * @param <T> the concrete operator type, returned by fluent methods
 */
@InputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.PREDICT_INPUT_MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_INPUT_DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@Internal
@ParamsIgnoredOnWebUI(names = {"modelFilePath"})
@ReservedColsWithSecondInputSpec
public class ModelMapBatchOp<T extends ModelMapBatchOp<T>> extends BatchOperator<T> implements HasModelFilePath<T> {
    /** Broadcast-variable name under which the model rows are shipped to the mapping tasks. */
    private static final String BROADCAST_MODEL_TABLE_NAME = "broadcastModelTable";
    /** Broadcast-set name used only as an ordering dependency on the iterable-model loading step. */
    private static final String LOAD_BARRIER_NAME = "barrier";
    private static final long serialVersionUID = 3479332090254995273L;

    /** Builds the mapper from (model schema, data schema, params). */
    protected final TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder;

    /**
     * @param mapperBuilder factory receiving (model schema, data schema, params)
     * @param params        operator parameters
     */
    public ModelMapBatchOp(TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder, Params params) {
        super(params);
        this.mapperBuilder = mapperBuilder;
    }

    /**
     * Links the model and data inputs and sets this operator's output to the mapped rows.
     *
     * <p>With exactly two inputs, {@code inputs[0]} is the model and {@code inputs[1]} is the data.
     * Otherwise {@code inputs[0]} is the data and the model is loaded from {@code modelFilePath}.
     *
     * @param inputs model and data inputs, or data input only
     * @return this operator
     * @throws AkIllegalOperatorParameterException if neither a model input nor modelFilePath is set
     * @throws AkUnclassifiedErrorException wrapping any exception that is not an {@link ExceptionWithErrorCode}
     */
    @Override
    public T linkFrom(BatchOperator<?>... inputs) {
        checkMinOpSize(1, inputs);

        BatchOperator<?> model = inputs.length == 2 ? inputs[0] : null;
        BatchOperator<?> data = inputs.length == 2 ? inputs[1] : inputs[0];

        if (model == null) {
            String modelFilePath = getParams().get(HasModelFilePath.MODEL_FILE_PATH);
            if (modelFilePath == null) {
                throw new AkIllegalOperatorParameterException("One of model or modelFilePath should be set.");
            }
            model = new AkSourceBatchOp()
                .setFilePath(FilePath.deserialize(modelFilePath))
                .setMLEnvironmentId(getMLEnvironmentId());
        }

        try {
            ModelMapper mapper = this.mapperBuilder.apply(model.getSchema(), data.getSchema(), getParams());
            setOutput(calcResultRows(model, data, mapper, getParams()), mapper.getOutputSchema());
            // Safe by the self-referential bound T extends ModelMapBatchOp<T>.
            @SuppressWarnings("unchecked")
            T self = (T) this;
            return self;
        } catch (ExceptionWithErrorCode e) {
            // Already carries an Alink error code; propagate unchanged.
            throw e;
        } catch (Exception e) {
            throw new AkUnclassifiedErrorException("Error. ", e);
        }
    }

    /**
     * Builds the DataSet of predicted rows for {@code input} using {@code mapper} and the rows of {@code model}.
     *
     * <p>Strategy selection:
     * <ul>
     *   <li>If {@code mapper} is an {@link IterableModelLoader}, the model is loaded per subtask from a
     *       replicated stream of model rows (see {@link #predictWithIterableModelLoader}).</li>
     *   <li>Else, if {@link ModelStreamUtils#useModelStreamFile(Params)} is true, rows are mapped by a
     *       {@link ModelStreamModelMapperAdapter}.</li>
     *   <li>Otherwise rows are mapped one by one, or in bunches when {@code predictBatchSize} is set.
     *       In either case a multi-threaded adapter is used when {@code numThreads > 1}.</li>
     * </ul>
     * In the non-iterable cases the model rows are rebalanced and shipped as the broadcast variable
     * {@value #BROADCAST_MODEL_TABLE_NAME}.
     *
     * @param model  operator providing the model rows
     * @param input  operator providing the rows to map
     * @param mapper mapper applied to each input row
     * @param params parameters holding numThreads, predictBatchSize and model-stream settings
     * @return the mapped rows
     * @throws AkIllegalOperatorParameterException if predictBatchSize is set to a value {@code <= 0}
     */
    public static DataSet<Row> calcResultRows(BatchOperator<?> model, BatchOperator<?> input, final ModelMapper mapper, Params params) {
        if (mapper instanceof IterableModelLoader) {
            return predictWithIterableModelLoader(model.getDataSet(), input, mapper, params);
        }

        final BroadcastVariableModelSource modelSource = new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
        DataSet<Row> modelRows = model.getDataSet().rebalance();

        if (ModelStreamUtils.useModelStreamFile(params)) {
            return input.getDataSet()
                .map(new RichMapFunction<Row, Row>() {
                    private static final long serialVersionUID = 1L;
                    // Created in open() on the task side, so it is still null when the function is serialized.
                    private ModelStreamModelMapperAdapter streamMapper;

                    @Override
                    public void open(Configuration parameters) throws Exception {
                        super.open(parameters);
                        mapper.loadModel(modelSource.getModelRows(getRuntimeContext()));
                        mapper.open();
                        this.streamMapper = new ModelStreamModelMapperAdapter(mapper);
                    }

                    @Override
                    public Row map(Row row) throws Exception {
                        return this.streamMapper.map(row);
                    }
                })
                .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
        }

        Integer predictBatchSize = params.get(HasPredictBatchSize.PREDICT_BATCH_SIZE);
        if (predictBatchSize != null && predictBatchSize <= 0) {
            throw new AkIllegalOperatorParameterException("batch size must larger than 0.");
        }
        int numThreads = params.get(ModelMapperParams.NUM_THREADS);
        DataSet<Row> data = input.getDataSet();

        if (predictBatchSize == null) {
            if (numThreads <= 1) {
                return data
                    .map(new ModelMapperAdapter(mapper, modelSource))
                    .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
            }
            return data
                .flatMap(new ModelMapperAdapterMT(mapper, modelSource, numThreads))
                .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
        }

        if (numThreads <= 1) {
            return data
                .mapPartition(new ModelBunchMapperAdapter(mapper, modelSource, predictBatchSize))
                .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
        }
        return data
            .mapPartition(new ModelBunchMapperAdapterMT(mapper, modelSource, numThreads, predictBatchSize))
            .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
    }

    /**
     * Prediction path for {@link IterableModelLoader} mappers.
     *
     * <p>Every subtask receives all model rows and feeds them to
     * {@link IterableModelLoader#loadIterableModel}. It then stores the loaded mapper in
     * {@link IterTaskObjKeeper} under (handle, subtask index). The prediction adapters receive the same
     * handle. They take the loading operator as a broadcast set purely as a dependency on that step.
     */
    private static DataSet<Row> predictWithIterableModelLoader(
        DataSet<Row> modelRows, BatchOperator<?> input, final ModelMapper mapper, Params params) {
        final long handle = IterTaskObjKeeper.getNewHandle();

        DataSet<Integer> loadBarrier = replicateToAllSubtasks(modelRows)
            .mapPartition(new RichMapPartitionFunction<Row, Integer>() {
                private static final long serialVersionUID = 2358845952757630826L;

                @Override
                public void mapPartition(Iterable<Row> rows, Collector<Integer> out) {
                    // Emits nothing: its only job is to load and register the mapper on this subtask.
                    int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
                    ((IterableModelLoader) mapper).loadIterableModel(rows);
                    IterTaskObjKeeper.put(handle, subtaskIndex, mapper);
                }
            });

        int numThreads = params.get(ModelMapperParams.NUM_THREADS);
        DataSet<Row> data = input.getDataSet();
        if (numThreads <= 1) {
            return data
                .map(new IterableModelLoaderModelMapperAdapter(handle))
                .withBroadcastSet(loadBarrier, LOAD_BARRIER_NAME);
        }
        return data
            .flatMap(new IterableModelLoaderModelMapperAdapterMT(handle, numThreads))
            .withBroadcastSet(loadBarrier, LOAD_BARRIER_NAME);
    }

    /**
     * Emits one copy of every row per subtask index and routes copy {@code i} to partition {@code i}.
     *
     * <p>The fan-out count is the parallelism of the replicating flatMap.
     * NOTE(review): this assumes the downstream operator runs with the same parallelism. Otherwise
     * indices can exceed the partition count, or some partitions receive nothing.
     */
    private static DataSet<Row> replicateToAllSubtasks(DataSet<Row> rows) {
        TypeInformation<Row> rowType = rows.getType();
        return rows
            .flatMap(new RichFlatMapFunction<Row, Tuple2<Integer, Row>>() {
                private static final long serialVersionUID = 3544759002096859673L;
                private int numTasks;

                @Override
                public void open(Configuration parameters) {
                    this.numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
                }

                @Override
                public void flatMap(Row row, Collector<Tuple2<Integer, Row>> out) {
                    for (int i = 0; i < this.numTasks; i++) {
                        out.collect(Tuple2.of(i, row));
                    }
                }
            })
            .returns(new TupleTypeInfo<Tuple2<Integer, Row>>(Types.INT, rowType))
            .partitionCustom(new Partitioner<Integer>() {
                private static final long serialVersionUID = -2924355974935165844L;

                @Override
                public int partition(Integer targetIndex, int numPartitions) {
                    return targetIndex;
                }
            }, 0)
            .map(new MapFunction<Tuple2<Integer, Row>, Row>() {
                private static final long serialVersionUID = 8884296007768771379L;

                @Override
                public Row map(Tuple2<Integer, Row> indexedRow) {
                    return indexedRow.f1;
                }
            })
            .returns(rowType);
    }
}
