package com.alibaba.alink.operator.stream.onlinelearning;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.directreader.DataBridge;
import com.alibaba.alink.common.io.directreader.DirectReader;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.onlinelearning.BaseOnlinePredictStreamOp;
import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil;
import com.alibaba.alink.operator.stream.utils.ModelMapStreamOp;
import com.alibaba.alink.operator.stream.utils.PredictProcess;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.function.TriFunction;

/**
 * Base stream operator for online-learning prediction.
 *
 * <p>An initial model is supplied at construction time as a {@link BatchOperator} and is
 * materialized through {@link DirectReader#collect} when the operator is linked.
 * {@link #linkFrom(StreamOperator[])} expects exactly two inputs:
 * <ol>
 *   <li>input 0 - the model stream, which is broadcast via
 *       {@link ModelMapStreamOp#broadcastStream};</li>
 *   <li>input 1 - the data stream to predict on.</li>
 * </ol>
 * The two streams are connected and processed by a {@link PredictProcess}; the output schema
 * is the one reported by the {@link ModelMapper} created by the supplied mapper builder.
 *
 * @param <T> the concrete operator type, returned from {@code linkFrom} for call chaining
 */
@InputPorts(values = {@PortSpec(value = PortType.MODEL, opType = PortSpec.OpType.BATCH), @PortSpec(value = PortType.MODEL_STREAM, opType = PortSpec.OpType.SAME), @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.SAME)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@Internal
@ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})
@ReservedColsWithSecondInputSpec
@NameCn("在线学习预测基类")
public class BaseOnlinePredictStreamOp<T extends BaseOnlinePredictStreamOp<T>> extends StreamOperator<T> {
    /**
     * Number of leading model-stream columns skipped when building the model schema.
     * NOTE(review): presumably these are stream bookkeeping columns - confirm against
     * ModelMapStreamOp / PredictProcess.
     */
    private static final int MODEL_COL_OFFSET = 2;

    /** Number of model-stream columns (after the offset) that form the model schema. */
    private static final int MODEL_COL_COUNT = 3;

    /** Handle to the collected initial model; assigned in {@code linkFrom}, {@code null} before that. */
    DataBridge dataBridge;

    /** Initial model; must be non-null by the time {@code linkFrom} is called. */
    private final BatchOperator<?> model;

    /** Factory creating a {@link ModelMapper} from (model schema, data schema, params). */
    private final TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder;

    /**
     * Creates the operator.
     *
     * @param batchOperator initial model; a {@code null} value is rejected later, in {@code linkFrom}
     * @param triFunction   builder producing the {@link ModelMapper} for the given schemas and params
     * @param params        operator parameters, forwarded to the parent operator
     */
    public BaseOnlinePredictStreamOp(BatchOperator<?> batchOperator, TriFunction<TableSchema, TableSchema, Params, ModelMapper> triFunction, Params params) {
        super(params);
        this.dataBridge = null;
        this.model = batchOperator;
        this.mapperBuilder = triFunction;
    }

    /**
     * Links the model stream (input 0) and the data stream (input 1) and sets the prediction
     * result as this operator's output table.
     *
     * @param streamOperatorArr exactly two operators: the model stream, then the data stream
     * @return this operator
     * @throws AkUnclassifiedErrorException if anything fails while linking (including a missing
     *         initial model or a model stream with too few columns); the original exception is
     *         attached as the cause
     */
    @Override // com.alibaba.alink.operator.stream.StreamOperator
    public T linkFrom(StreamOperator<?>... streamOperatorArr) {
        checkOpSize(2, streamOperatorArr);
        try {
            if (this.model == null) {
                throw new AkIllegalModelException("online algo: initial model is null. Please set a valid initial model.");
            }
            this.dataBridge = DirectReader.collect(this.model);

            StreamOperator<?> modelStreamOp = streamOperatorArr[0];
            StreamOperator<?> dataOp = streamOperatorArr[1];

            DataStream<Row> broadcastStream = ModelMapStreamOp.broadcastStream(modelStreamOp.getDataStream());
            TableSchema modelSchema = extractModelSchema(modelStreamOp.getSchema());
            TableSchema dataSchema = dataOp.getSchema();

            // The trailing 0 and 1 are passed through unchanged from the original code.
            // NOTE(review): presumably indices of the two skipped leading columns - confirm
            // against PredictProcess.
            DataStream<Row> predicted = (DataStream<Row>) dataOp.getDataStream()
                .connect(broadcastStream)
                .flatMap(new PredictProcess(modelSchema, dataSchema, getParams(), this.mapperBuilder, this.dataBridge, 0, 1));

            // A mapper is built here only to obtain the output schema.
            TableSchema outputSchema = this.mapperBuilder.apply(modelSchema, dataSchema, getParams()).getOutputSchema();

            setOutputTable(DataStreamConversionUtil.toTable(getMLEnvironmentId(), predicted, outputSchema));

            // Self-type cast: safe as long as subclasses bind T to themselves.
            @SuppressWarnings("unchecked")
            T self = (T) this;
            return self;
        } catch (Exception e) {
            // Keep the historical exception type and message, but chain the cause so the
            // original stack trace is preserved instead of being printed to stderr.
            throw new AkUnclassifiedErrorException(e.toString(), e);
        }
    }

    /**
     * Builds the model schema from the model-stream schema by taking {@link #MODEL_COL_COUNT}
     * columns starting at index {@link #MODEL_COL_OFFSET}.
     */
    private static TableSchema extractModelSchema(TableSchema modelStreamSchema) {
        String[] allNames = modelStreamSchema.getFieldNames();
        TypeInformation<?>[] allTypes = modelStreamSchema.getFieldTypes();
        int required = MODEL_COL_OFFSET + MODEL_COL_COUNT;
        if (allNames.length < required || allTypes.length < required) {
            throw new AkIllegalModelException("online algo: model stream must have at least " + required
                + " columns, but got " + allNames.length + ".");
        }
        String[] names = new String[MODEL_COL_COUNT];
        TypeInformation<?>[] types = new TypeInformation<?>[MODEL_COL_COUNT];
        System.arraycopy(allNames, MODEL_COL_OFFSET, names, 0, MODEL_COL_COUNT);
        System.arraycopy(allTypes, MODEL_COL_OFFSET, types, 0, MODEL_COL_COUNT);
        return new TableSchema(names, types);
    }
}
