package com.alibaba.alink.operator.stream.utils;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.directreader.DataBridge;
import com.alibaba.alink.common.io.directreader.DirectReader;
import com.alibaba.alink.common.io.filesystem.AkUtils;
import com.alibaba.alink.common.io.filesystem.FilePath;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.mapper.ModelMapperAdapter;
import com.alibaba.alink.common.mapper.ModelMapperAdapterMT;
import com.alibaba.alink.common.model.DataBridgeModelSource;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.source.ModelStreamFileSourceStreamOp;
import com.alibaba.alink.operator.stream.utils.ModelMapStreamOp;
import com.alibaba.alink.params.ModelStreamScanParams;
import com.alibaba.alink.params.io.AkSourceParams;
import com.alibaba.alink.params.io.ModelFileSinkParams;
import com.alibaba.alink.params.mapper.ModelMapperParams;
import com.alibaba.alink.params.shared.HasModelFilePath;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.function.TriFunction;

@InputPorts(values = {
    @PortSpec(value = PortType.MODEL, opType = PortSpec.OpType.BATCH, desc = PortDesc.PREDICT_INPUT_MODEL),
    @PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_INPUT_DATA),
    @PortSpec(value = PortType.MODEL_STREAM, isOptional = true, desc = PortDesc.PREDICT_INPUT_MODEL_STREAM)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@Internal
@ReservedColsWithSecondInputSpec
/**
 * Stream operator that maps every row of its first (data) input through a {@link ModelMapper}.
 *
 * <p>The initial model is taken either from the batch operator passed to the constructor or, when none is
 * given, from the {@code modelFilePath} parameter. If a model-stream file is configured and/or a second input
 * stream is linked, model rows from those sources are broadcast alongside the data and handed to
 * {@code PredictProcess}; otherwise the mapper is applied with a static model.
 *
 * @param <T> the concrete operator type, for fluent setters
 */
public class ModelMapStreamOp<T extends ModelMapStreamOp<T>> extends StreamOperator<T>
    implements ModelStreamScanParams<T>, HasModelFilePath<T> {
    private static final long serialVersionUID = -6591412871091394859L;

    /** Batch operator providing the model; {@code null} when the model comes from {@code modelFilePath}. */
    protected final BatchOperator<?> model;

    /** Builds the mapper from (model schema, data schema, params). */
    protected final TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder;

    /** {@link DataBridge} that reads model rows from an Ak file on a file system. */
    public static final class FilePathDataBridge implements DataBridge {
        private final FilePath filePath;

        private FilePathDataBridge(FilePath filePath) {
            this.filePath = filePath;
        }

        /**
         * Reads the rows stored at the file path, passing {@code filter} through to {@link AkUtils#readFromPath}.
         *
         * @throws AkUnclassifiedErrorException if reading fails
         */
        @Override
        public List<Row> read(FilterFunction<Row> filter) {
            try {
                return AkUtils.readFromPath(filePath, filter).f1;
            } catch (Exception e) {
                throw new AkUnclassifiedErrorException("Error. ", e);
            }
        }
    }

    /** {@link DataBridge} backed by an in-memory {@link MTable}. */
    public static class MTableDataBridge implements DataBridge {
        private final MTable mt;

        public MTableDataBridge(MTable mTable) {
            this.mt = mTable;
        }

        /**
         * Returns the rows of the wrapped table.
         *
         * <p>When {@code filter} is non-null, only rows it accepts are returned, matching the behaviour of
         * {@link FilePathDataBridge}. With a {@code null} filter, the table's row list is returned as-is.
         *
         * @throws AkUnclassifiedErrorException if the filter throws
         */
        @Override
        public List<Row> read(FilterFunction<Row> filter) {
            List<Row> rows = mt.getRows();
            if (filter == null) {
                return rows;
            }
            List<Row> accepted = new ArrayList<>(rows.size());
            try {
                for (Row row : rows) {
                    if (filter.filter(row)) {
                        accepted.add(row);
                    }
                }
            } catch (Exception e) {
                throw new AkUnclassifiedErrorException("Error. ", e);
            }
            return accepted;
        }
    }

    /**
     * Creates the operator without a batch model; the model is then loaded from the {@code modelFilePath} param.
     */
    public ModelMapStreamOp(TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder, Params params) {
        super(params);
        this.model = null;
        this.mapperBuilder = mapperBuilder;
    }

    /** Creates the operator with a batch operator providing the model. */
    public ModelMapStreamOp(BatchOperator<?> model,
                            TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder,
                            Params params) {
        super(params);
        this.model = model;
        this.mapperBuilder = mapperBuilder;
    }

    /**
     * Replicates every input row once per parallel subtask and routes copy {@code i} to partition {@code i}.
     *
     * <p>Anonymous classes (not lambdas) are kept so that Flink can extract the generic
     * {@code Tuple2<Integer, Row>} / {@code Row} types from their signatures.
     * NOTE(review): the partition index is the flatMap's subtask count; this assumes the downstream map has the
     * same parallelism — confirm.
     *
     * @param dataStream rows to replicate
     * @return the replicated rows, one copy per partition index
     */
    public static DataStream<Row> broadcastStream(DataStream<Row> dataStream) {
        return dataStream
            .flatMap(new RichFlatMapFunction<Row, Tuple2<Integer, Row>>() {
                private static final long serialVersionUID = 6421400378693673120L;

                @Override
                public void flatMap(Row row, Collector<Tuple2<Integer, Row>> out) throws Exception {
                    int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
                    for (int i = 0; i < numSubtasks; i++) {
                        out.collect(Tuple2.of(i, row));
                    }
                }
            })
            .partitionCustom(new Partitioner<Integer>() {
                @Override
                public int partition(Integer key, int numPartitions) {
                    return key;
                }
            }, 0)
            .map(new MapFunction<Tuple2<Integer, Row>, Row>() {
                @Override
                public Row map(Tuple2<Integer, Row> value) throws Exception {
                    return value.f1;
                }
            });
    }

    /**
     * Links the data input (index 0) and the optional model-stream input (index 1), then builds the output.
     *
     * @throws AkUnclassifiedErrorException wrapping any failure while loading the model or building the mapper
     */
    @Override
    @SuppressWarnings("unchecked")
    public T linkFrom(StreamOperator<?>... inputs) {
        checkMinOpSize(1, inputs);
        StreamOperator<?> data = inputs[0];
        StreamOperator<?> modelStream = inputs.length > 1 ? inputs[1] : null;
        try {
            String modelFilePath = getParams().get(ModelFileSinkParams.MODEL_FILE_PATH);
            Tuple2<DataBridge, TableSchema> bridgeAndSchema = createDataBridge(modelFilePath, this.model);
            ModelMapper mapper = this.mapperBuilder.apply(bridgeAndSchema.f1, data.getSchema(), getParams());
            DataStream<Row> results = calcResultRows(
                bridgeAndSchema.f0, bridgeAndSchema.f1, data, modelStream,
                mapper, getParams(), getMLEnvironmentId(), this.mapperBuilder);
            setOutput(results, mapper.getOutputSchema());
            // Safe by the F-bounded declaration T extends ModelMapStreamOp<T>.
            return (T) this;
        } catch (Exception e) {
            throw new AkUnclassifiedErrorException(e.getMessage(), e);
        }
    }

    /**
     * Builds the result stream for {@code data}.
     *
     * <p>Model-update rows are gathered from a model-stream file (when
     * {@link ModelStreamUtils#useModelStreamFile} says so) and/or from {@code modelStream}. If any exist, they are
     * broadcast and connected with the data into a {@code PredictProcess}. Otherwise the mapper runs with the static
     * model only: through {@link ModelMapperAdapter} when {@code numThreads <= 1}, else through
     * {@link ModelMapperAdapterMT}.
     *
     * @param modelDataBridge source of the initial model rows
     * @param modelSchema     schema of the model rows
     * @param data            data input stream
     * @param modelStream     optional model-update stream, may be {@code null}
     * @param mapper          mapper used when there is no model-update stream
     * @param params          operator params
     * @param mlEnvId         ML environment id for the model-stream file source
     * @param mapperBuilder   mapper factory passed on to {@code PredictProcess}
     * @return the mapped rows
     */
    public static DataStream<Row> calcResultRows(DataBridge modelDataBridge, TableSchema modelSchema,
                                                 StreamOperator<?> data, StreamOperator<?> modelStream,
                                                 ModelMapper mapper, Params params, Long mlEnvId,
                                                 TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder) {
        DataBridgeModelSource modelSource = new DataBridgeModelSource(modelDataBridge);
        DataStream<Row> modelRows = null;
        TableSchema modelStreamSchema = null;

        if (ModelStreamUtils.useModelStreamFile(params)) {
            StreamOperator<?> fileModelStream = new ModelStreamFileSourceStreamOp()
                .setFilePath(FilePath.deserialize(params.get(ModelStreamScanParams.MODEL_STREAM_FILE_PATH)))
                .setScanInterval(params.get(ModelStreamScanParams.MODEL_STREAM_SCAN_INTERVAL))
                .setStartTime(params.get(ModelStreamScanParams.MODEL_STREAM_START_TIME))
                .setSchemaStr(TableUtil.schema2SchemaStr(modelSchema))
                .setMLEnvironmentId(mlEnvId);
            modelStreamSchema = fileModelStream.getSchema();
            modelRows = fileModelStream.getDataStream();
        }

        if (modelStream != null) {
            if (modelRows == null) {
                modelStreamSchema = modelStream.getSchema();
                modelRows = modelStream.getDataStream();
            } else {
                // Project the linked stream onto the file stream's columns so the two can be unioned.
                modelRows = modelRows.union(modelStream.select(modelStreamSchema.getFieldNames()).getDataStream());
            }
        }

        if (modelRows != null) {
            return data.getDataStream()
                .connect(broadcastStream(modelRows))
                .flatMap(new PredictProcess(modelSchema, data.getSchema(), params, mapperBuilder, modelDataBridge,
                    ModelStreamUtils.findTimestampColIndexWithAssertAndHint(modelStreamSchema),
                    ModelStreamUtils.findCountColIndexWithAssertAndHint(modelStreamSchema)));
        }

        int numThreads = params.get(ModelMapperParams.NUM_THREADS);
        if (numThreads <= 1) {
            return data.getDataStream().map(new ModelMapperAdapter(mapper, modelSource));
        }
        return data.getDataStream().flatMap(new ModelMapperAdapterMT(mapper, modelSource, numThreads));
    }

    /**
     * Creates a {@link DataBridge} for the model together with the model's schema.
     *
     * <ul>
     *   <li>A batch model that is neither an {@link AkSourceBatchOp} nor a {@link MemSourceBatchOp} is collected via
     *   {@link DirectReader#collect}.</li>
     *   <li>A {@link MemSourceBatchOp} is wrapped in an {@link MTableDataBridge}.</li>
     *   <li>Otherwise the model is read from a file: the {@link AkSourceBatchOp}'s file path, or
     *   {@code modelFilePath} when {@code model} is {@code null}.</li>
     * </ul>
     *
     * @param modelFilePath serialized {@link FilePath}; only used when {@code model} is {@code null}
     * @param model         batch model operator, may be {@code null}
     * @return the bridge and the model schema
     * @throws IllegalArgumentException if both arguments are {@code null}, or the model file does not exist
     * @throws IOException              if checking the file system or reading the file meta fails
     */
    public static Tuple2<DataBridge, TableSchema> createDataBridge(String modelFilePath, BatchOperator<?> model)
        throws IOException {
        if (modelFilePath == null && model == null) {
            throw new IllegalArgumentException("One of model or modelFilePath should be set.");
        }
        if (model != null && !(model instanceof AkSourceBatchOp) && !(model instanceof MemSourceBatchOp)) {
            return Tuple2.of(DirectReader.collect(model), model.getSchema());
        }
        if (model instanceof MemSourceBatchOp) {
            MTable mt = ((MemSourceBatchOp) model).getMt();
            return Tuple2.of(new MTableDataBridge(mt), mt.getSchema());
        }

        FilePath filePath = model == null
            ? FilePath.deserialize(modelFilePath)
            : FilePath.deserialize(model.getParams().get(AkSourceParams.FILE_PATH));

        if (!filePath.getFileSystem().exists(filePath.getPath())) {
            throw new IllegalArgumentException(
                "When use model file path, the model should be sink first. If using pipeline model, it should be save model first.");
        }
        return Tuple2.of(new FilePathDataBridge(filePath),
            TableUtil.schemaStr2Schema(AkUtils.getMetaFromPath(filePath).schemaStr));
    }
}
