package com.alibaba.alink.operator.stream.onlinelearning;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkUnimplementedOperationException;
import com.alibaba.alink.common.io.directreader.DataBridge;
import com.alibaba.alink.common.io.directreader.DirectReader;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.MapperChain;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.onlinelearning.kernel.FmOnlineLearningKernel;
import com.alibaba.alink.operator.stream.onlinelearning.kernel.LinearOnlineLearningKernel;
import com.alibaba.alink.operator.stream.onlinelearning.kernel.OnlineLearningKernel;
import com.alibaba.alink.operator.stream.onlinelearning.kernel.SoftmaxOnlineLearningKernel;
import com.alibaba.alink.operator.stream.source.ModelStreamFileSourceStreamOp;
import com.alibaba.alink.operator.stream.source.NumSeqSourceStreamOp;
import com.alibaba.alink.params.onlinelearning.OnlineLearningTrainParams;
import com.alibaba.alink.pipeline.ModelExporterUtils;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.PipelineStageBase;
import com.alibaba.alink.pipeline.classification.FmClassificationModel;
import com.alibaba.alink.pipeline.classification.LinearSvmModel;
import com.alibaba.alink.pipeline.classification.LogisticRegressionModel;
import com.alibaba.alink.pipeline.classification.SoftmaxModel;
import com.alibaba.alink.pipeline.regression.FmRegressionModel;
import com.alibaba.alink.pipeline.regression.LinearRegressionModel;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.IterativeStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Stream operator that keeps refining the last stage of a pre-trained pipeline model on incoming rows.
 *
 * <p>The initial pipeline model is supplied to the constructor (as a {@link BatchOperator} or a
 * {@link PipelineModel}). All stages except the last are applied to every incoming row as preprocessing; the last
 * stage (LR, SVM, linear regression, softmax, FM classification or FM regression) is updated by an
 * {@link OnlineLearningKernel}. Local gradients of all subtasks are summed and broadcast back through a Flink
 * iteration. A timer injects a tick every {@code timeInterval} seconds; on a tick, a subtask whose model changed
 * emits the whole pipeline model as a model stream (a timestamp column and a row-count column prepended to the
 * model schema). If a model stream file is configured, complete models read from it rebase the pipeline.
 */
@NameCn("在线学习")
@NameEn("Online learning")
public class OnlineLearningStreamOp extends StreamOperator<OnlineLearningStreamOp>
    implements OnlineLearningTrainParams<OnlineLearningStreamOp> {
    private static final long serialVersionUID = 3688413917992858013L;

    /** Handle to the collected rows of the initial pipeline model; read back by every {@link ModelUpdater}. */
    private final DataBridge dataBridge;

    /** Schema string of the initial pipeline-model table. */
    private final String modelSchemeStr;

    /**
     * Per-subtask learner inside the Flink iteration.
     *
     * <p>Input 1 carries training rows. Input 2 carries feedback: aggregated gradients ({@link Map}), model-stream
     * rows ({@link Row}) and timer ticks ({@link Long}). Emitted tuples with {@code f0 >= 0} are local gradients
     * tagged with a version counter; tuples with {@code f0 == -1} are model-stream rows.
     */
    public static class ModelUpdater
        extends RichCoFlatMapFunction<Row, Tuple2<Long, Object>, Tuple2<Long, Object>>
        implements CheckpointedFunction {
        private static final long serialVersionUID = 7338858137860097282L;

        private final DataBridge dataBridge;
        private final Params params;
        private final String modelSchemaStr;
        private final String dataSchemaStr;

        /** Index of the vector feature column in the preprocessed row, or -1 when features are separate columns. */
        private int vectorTrainIdx;
        private int labelIdx;
        private int[] featureIdx;
        /** Training rows received since the last local gradient was emitted. */
        private List<Row> localBatchDataBuffer;
        /** Aggregated gradients received but not yet applied to the kernel. */
        private List<Object> gradientBuffer;
        /** Upper bound on buffered training rows while waiting for an aggregated gradient. */
        private int maxNumBatches;
        /** Rows of the current pipeline model (all stages plus the meta row with id -1). */
        private List<Row> pipeModelRows;
        /** Value of the "id" column that marks rows of the learned (last) stage. */
        private long linearModelId;
        private int[] rowFieldMapping;
        private int idIdx;
        /** Preprocessing stages, i.e. every mapper of the pipeline except the last one. */
        private MapperChain mapperChain;
        private OnlineLearningKernel kernel;
        private transient ListState<OnlineLearningKernel> modelState;
        private long gradientVersion = 0;
        /** True once the last emitted local gradient has been answered by an aggregated gradient. */
        private boolean isUpdatedModel = true;
        /** Partially received model-stream models, keyed by their model timestamp. */
        private final Map<Timestamp, List<Row>> buffers = new HashMap<>();
        /** True when the kernel changed since the model was last emitted. */
        private boolean isOutputModel = true;

        public ModelUpdater(DataBridge dataBridge, Params params, String modelSchemaStr, String dataSchemaStr) {
            this.dataBridge = dataBridge;
            this.params = params;
            this.modelSchemaStr = modelSchemaStr;
            this.dataSchemaStr = dataSchemaStr;
        }

        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.maxNumBatches = Math.max(1, 100000 * getRuntimeContext().getNumberOfParallelSubtasks());
            // Created eagerly: feedback (flatMap2) may arrive before the first training row (flatMap1).
            this.localBatchDataBuffer = new ArrayList<>();
            this.gradientBuffer = new ArrayList<>();
        }

        @Override
        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
            this.modelState.clear();
            this.modelState.add(this.kernel);
        }

        /**
         * Registers the kernel list state and builds the pipeline (preprocessing mappers, column indices, model
         * rows) from the initial model. Only the kernel is checkpointed, so the pipeline is rebuilt on every start;
         * when restoring, the checkpointed kernel replaces the one deserialized from the initial model.
         */
        @Override
        public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
            this.modelState = functionInitializationContext.getOperatorStateStore().getListState(
                new ListStateDescriptor<>("StreamingOnlineModelState", TypeInformation.of(OnlineLearningKernel.class)));
            OnlineLearningKernel restoredKernel = null;
            if (functionInitializationContext.isRestored()) {
                // After rescaling a subtask may receive several entries (the last one wins) or none at all
                // (then the initial model's kernel is used).
                for (OnlineLearningKernel stored : this.modelState.get()) {
                    restoredKernel = stored;
                }
            }
            loadPipelineModel(DirectReader.directRead(this.dataBridge),
                TableUtil.schemaStr2Schema(this.modelSchemaStr),
                TableUtil.schemaStr2Schema(this.dataSchemaStr),
                restoredKernel);
        }

        /**
         * Buffers a training row. If the previous local gradient has been answered, computes a new local gradient
         * over all buffered rows and emits it with an increasing version number. Otherwise the buffer is trimmed
         * once it exceeds {@code maxNumBatches}.
         */
        @Override
        public void flatMap1(Row row, Collector<Tuple2<Long, Object>> collector) throws Exception {
            this.localBatchDataBuffer.add(row);
            if (!this.isUpdatedModel) {
                if (this.localBatchDataBuffer.size() > this.maxNumBatches) {
                    // Drop the oldest rows (about 90% of the limit) to bound memory while waiting.
                    this.localBatchDataBuffer.subList(0, Math.min(this.maxNumBatches - 1, (this.maxNumBatches * 9) / 10))
                        .clear();
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println("remove batches happened. ");
                    }
                }
                return;
            }
            this.kernel.getGradient().clear();
            for (Row buffered : this.localBatchDataBuffer) {
                Row mapped = this.mapperChain.map(buffered);
                this.kernel.calcGradient(extractFeatureVector(mapped), mapped.getField(this.labelIdx));
            }
            this.localBatchDataBuffer.clear();
            this.isUpdatedModel = false;
            collector.collect(Tuple2.<Long, Object>of(this.gradientVersion++, this.kernel.getGradient()));
        }

        /** Builds the feature vector of a preprocessed row, from a vector column or from separate numeric columns. */
        private Vector extractFeatureVector(Row mapped) {
            if (this.vectorTrainIdx != -1) {
                return VectorUtil.getVector(mapped.getField(this.vectorTrainIdx));
            }
            DenseVector dense = new DenseVector(this.featureIdx.length);
            for (int i = 0; i < this.featureIdx.length; i++) {
                dense.set(i, Double.parseDouble(mapped.getField(this.featureIdx[i]).toString()));
            }
            return dense;
        }

        /**
         * Dispatches feedback: a {@link Map} is an aggregated gradient, a {@link Row} is a model-stream row and a
         * {@link Long} is a timer tick that triggers model output.
         *
         * @throws AkIllegalDataException for any other payload type
         */
        @Override
        public void flatMap2(Tuple2<Long, Object> tuple2, Collector<Tuple2<Long, Object>> collector) {
            Object payload = tuple2.f1;
            if (payload instanceof Map) {
                onGradient(payload);
            } else if (payload instanceof Row) {
                onModelRow((Row) payload);
            } else if (payload instanceof Long) {
                if (this.isOutputModel) {
                    emitModel(collector);
                }
            } else {
                throw new AkIllegalDataException("feedback data type err, must be a Map, Row or Long.");
            }
        }

        /** Queues an aggregated gradient; applies the oldest queued one if a local gradient is awaiting an answer. */
        private void onGradient(Object gradient) {
            this.gradientBuffer.add(gradient);
            if (!this.isUpdatedModel) {
                this.kernel.updateModel(this.gradientBuffer.remove(0));
                this.isUpdatedModel = true;
                this.isOutputModel = true;
            }
        }

        /**
         * Collects one model-stream row (field 0: model timestamp, field 1: total row count). When all rows of a
         * model have arrived, the pipeline is rebased on it. Failures are reported and the current pipeline is kept.
         */
        private void onModelRow(Row row) {
            Timestamp timestamp = (Timestamp) row.getField(0);
            long expectedCount = (Long) row.getField(1);
            List<Row> rows = this.buffers.computeIfAbsent(timestamp, key -> new ArrayList<>());
            rows.add(ModelStreamUtils.genRowWithoutIdentifier(row, 0, 1));
            if (rows.size() != (int) expectedCount) {
                return;
            }
            // The model is complete: release its buffer so finished models do not accumulate.
            this.buffers.remove(timestamp);
            try {
                loadPipelineModel(rows, TableUtil.schemaStr2Schema(this.modelSchemaStr),
                    TableUtil.schemaStr2Schema(this.dataSchemaStr), null);
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("rebase pipeline model.");
                }
            } catch (Exception e) {
                // Best effort: a broken streamed model must not stop training on the current one.
                System.err.println("test Model stream updating failed. Please check your model stream.");
                System.err.println("Cause: " + e);
            }
        }

        /**
         * Emits the whole pipeline model as model-stream rows tagged with {@code f0 == -1}: the unchanged rows of
         * all other stages, followed by the kernel's current rows for the learned stage.
         */
        private void emitModel(Collector<Tuple2<Long, Object>> collector) {
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("output model begin...");
            }
            List<Row> retained = new ArrayList<>(this.pipeModelRows.size());
            for (Row modelRow : this.pipeModelRows) {
                if (((Long) modelRow.getField(this.idIdx)).longValue() != this.linearModelId) {
                    retained.add(modelRow);
                }
            }
            List<Row> learnedRows = this.kernel.serializeModel();
            Timestamp timestamp = new Timestamp(System.currentTimeMillis());
            long total = learnedRows.size() + retained.size();
            // All rows of the pipeline-model table share the same arity.
            int arity = this.pipeModelRows.get(0).getArity();
            for (Row modelRow : retained) {
                Row out = newModelStreamRow(timestamp, total, arity);
                for (int i = 0; i < arity; i++) {
                    out.setField(i + 2, modelRow.getField(i));
                }
                collector.collect(Tuple2.<Long, Object>of(-1L, out));
            }
            for (Row learned : learnedRows) {
                Row out = newModelStreamRow(timestamp, total, arity);
                // NOTE(review): index 2 and offset 3 assume the "id" column is the first model-schema column
                // (idIdx == 0) — confirm.
                out.setField(2, this.linearModelId);
                for (int i = 0; i < this.rowFieldMapping.length; i++) {
                    out.setField(this.rowFieldMapping[i] + 3, learned.getField(i));
                }
                collector.collect(Tuple2.<Long, Object>of(-1L, out));
            }
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("output model OK.");
            }
            this.isOutputModel = false;
        }

        /** Creates a row with {@code arity + 2} fields whose first two fields hold the timestamp and the row count. */
        private static Row newModelStreamRow(Timestamp timestamp, long count, int arity) {
            Row row = new Row(arity + 2);
            row.setField(0, timestamp);
            row.setField(1, count);
            return row;
        }

        /**
         * Builds preprocessing mappers, column indices and the kernel from pipeline-model rows. Everything is
         * computed into locals first and committed to the fields at the end, so a failure leaves the previous
         * pipeline untouched.
         *
         * @param modelRows      rows of the pipeline model
         * @param modelSchema    schema of the pipeline-model table
         * @param dataSchema     schema of the training input rows
         * @param restoredKernel kernel restored from a checkpoint, or {@code null} to deserialize one from the
         *                       last stage of {@code modelRows}
         * @throws AkUnimplementedOperationException if the last stage is not a supported model type
         */
        private void loadPipelineModel(List<Row> modelRows, TableSchema modelSchema, TableSchema dataSchema,
                                       OnlineLearningKernel restoredKernel) {
            List<Tuple3<PipelineStageBase<?>, TableSchema, List<Row>>> stages =
                ModelExporterUtils.loadStagesFromPipelineModel(modelRows, modelSchema);
            Tuple3<PipelineStageBase<?>, TableSchema, List<Row>> lastStage = stages.get(stages.size() - 1);
            int newIdIdx = TableUtil.findColIndexWithAssert(modelSchema, "id");

            // The meta row (id == -1) describes the stage tree; keep the previous mapping if it is absent.
            int[] newRowFieldMapping = this.rowFieldMapping;
            for (Row modelRow : modelRows) {
                if (((Long) modelRow.getField(newIdIdx)).longValue() == -1) {
                    ModelExporterUtils.StageNode[] nodes =
                        ModelExporterUtils.deserializePipelineStagesFromMeta(modelRow, modelSchema);
                    newRowFieldMapping = nodes[nodes.length - 1].schemaIndices;
                    break;
                }
            }

            MapperChain fullChain = ModelExporterUtils.loadMapperListFromStages(stages, dataSchema);
            fullChain.open();
            Mapper[] allMappers = fullChain.getMappers();
            Mapper[] preprocessors = new Mapper[allMappers.length - 1];
            System.arraycopy(allMappers, 0, preprocessors, 0, preprocessors.length);

            OnlineLearningKernel newKernel = restoredKernel;
            if (newKernel == null) {
                newKernel = createKernel(lastStage.f0, this.params);
                newKernel.deserializeModel(lastStage.f2);
            }
            // Without preprocessing stages the kernel reads the raw training rows directly.
            TableSchema featureSchema = preprocessors.length == 0
                ? dataSchema
                : preprocessors[preprocessors.length - 1].getOutputSchema();
            int[] newFeatureIdx = newKernel.getFeatureIndices(featureSchema);
            int newLabelIdx = newKernel.getLabelIdx(featureSchema);
            int newVectorIdx = newKernel.getVectorIdx(featureSchema);

            this.pipeModelRows = modelRows;
            this.idIdx = newIdIdx;
            this.linearModelId = stages.size();
            this.rowFieldMapping = newRowFieldMapping;
            this.mapperChain = new MapperChain(preprocessors);
            this.kernel = newKernel;
            this.featureIdx = newFeatureIdx;
            this.labelIdx = newLabelIdx;
            this.vectorTrainIdx = newVectorIdx;
        }

        /** Chooses the online-learning kernel matching the type of the pipeline's last stage. */
        private static OnlineLearningKernel createKernel(PipelineStageBase<?> stage, Params params) {
            if (stage instanceof LogisticRegressionModel) {
                return new LinearOnlineLearningKernel(params, LinearModelType.LR);
            }
            if (stage instanceof FmClassificationModel) {
                return new FmOnlineLearningKernel(params, true);
            }
            if (stage instanceof FmRegressionModel) {
                return new FmOnlineLearningKernel(params, false);
            }
            if (stage instanceof SoftmaxModel) {
                return new SoftmaxOnlineLearningKernel(params);
            }
            if (stage instanceof LinearRegressionModel) {
                return new LinearOnlineLearningKernel(params, LinearModelType.LinearReg);
            }
            if (stage instanceof LinearSvmModel) {
                return new LinearOnlineLearningKernel(params, LinearModelType.SVM);
            }
            throw new AkUnimplementedOperationException(
                "Not support this stage yet, online learning only support LR, FMClassification, FMRegression, SVM, LinearReg, Softmax.");
        }
    }

    /**
     * Sums two sparse gradients in place. Both payloads are maps from an integer key to a {@code double[]}; entries
     * of the second map are added element-wise into the first, which is returned.
     */
    public static class ReduceGradient implements ReduceFunction<Tuple2<Long, Object>> {
        @Override
        @SuppressWarnings("unchecked")
        public Tuple2<Long, Object> reduce(Tuple2<Long, Object> tuple2, Tuple2<Long, Object> tuple22) {
            Map<Integer, double[]> target = (Map<Integer, double[]>) tuple2.f1;
            Map<Integer, double[]> source = (Map<Integer, double[]>) tuple22.f1;
            for (Map.Entry<Integer, double[]> entry : source.entrySet()) {
                double[] accumulated = target.get(entry.getKey());
                if (accumulated == null) {
                    target.put(entry.getKey(), entry.getValue());
                } else {
                    double[] addend = entry.getValue();
                    for (int i = 0; i < accumulated.length; i++) {
                        accumulated[i] += addend[i];
                    }
                }
            }
            return tuple2;
        }
    }

    public OnlineLearningStreamOp(PipelineModel pipelineModel) {
        this(pipelineModel.save(), new Params());
    }

    public OnlineLearningStreamOp(BatchOperator<?> batchOperator) {
        this(batchOperator, new Params());
    }

    /**
     * @param batchOperator the initial pipeline model
     * @param params        operator parameters
     * @throws AkIllegalArgumentException if {@code batchOperator} is null
     */
    public OnlineLearningStreamOp(BatchOperator<?> batchOperator, Params params) {
        super(params);
        if (batchOperator == null) {
            throw new AkIllegalArgumentException("Online algo: initial model is null. Please set the initial model.");
        }
        this.dataBridge = DirectReader.collect(batchOperator);
        this.modelSchemeStr = TableUtil.schema2SchemaStr(batchOperator.getSchema());
    }

    /**
     * Wires the training iteration: rebalanced training rows feed {@link ModelUpdater}; local gradients are summed
     * over {@code parallelism} contributions and broadcast back together with optional model-stream rows and timer
     * ticks. The operator output is the model stream.
     *
     * <p>Side effect: sets exactly-once forced checkpointing and a 20 ms buffer timeout on the stream environment.
     */
    @Override
    public OnlineLearningStreamOp linkFrom(StreamOperator<?>... streamOperatorArr) {
        StreamExecutionEnvironment env =
            MLEnvironmentFactory.get(getMLEnvironmentId()).getStreamExecutionEnvironment();
        checkOpSize(1, streamOperatorArr);
        String dataSchemaStr = TableUtil.schema2SchemaStr(streamOperatorArr[0].getSchema());
        int parallelism = env.getParallelism();
        env.getCheckpointConfig().setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);
        env.getCheckpointConfig().setForceCheckpointing(true);
        env.setBufferTimeout(20L);
        Params params = getParams();
        // long arithmetic: an int product would overflow for intervals above ~24 days.
        final long intervalMillis = getTimeInterval() * 1000L;

        // Timer: a single seed element expanded into an endless sequence of ticks.
        DataStream<Tuple2<Long, Object>> ticks = new NumSeqSourceStreamOp(0L, 0L)
            .setMLEnvironmentId(getMLEnvironmentId())
            .getDataStream()
            .flatMap(new FlatMapFunction<Row, Tuple2<Long, Object>>() {
                @Override
                public void flatMap(Row row, Collector<Tuple2<Long, Object>> collector) throws Exception {
                    for (long tick = 0; tick < Long.MAX_VALUE; tick++) {
                        Thread.sleep(intervalMillis);
                        collector.collect(Tuple2.<Long, Object>of(0L, 0L));
                    }
                }
            });

        DataStream<Row> trainData = streamOperatorArr[0].getDataStream().rebalance();

        DataStream<Tuple2<Long, Object>> modelFileStream = null;
        if (ModelStreamUtils.useModelStreamFile(params)) {
            modelFileStream = new ModelStreamFileSourceStreamOp()
                .setFilePath(getModelStreamFilePath())
                .setScanInterval(getModelStreamScanInterval())
                .setStartTime(getModelStreamStartTime())
                .setSchemaStr(this.modelSchemeStr)
                .setMLEnvironmentId(streamOperatorArr[0].getMLEnvironmentId())
                .getDataStream()
                .map(new MapFunction<Row, Tuple2<Long, Object>>() {
                    @Override
                    public Tuple2<Long, Object> map(Row row) {
                        return Tuple2.<Long, Object>of(-1L, row);
                    }
                });
        }

        IterativeStream.ConnectedIterativeStreams<Row, Tuple2<Long, Object>> iteration = trainData
            .iterate(Long.MAX_VALUE)
            .withFeedbackType(TypeInformation.of(new TypeHint<Tuple2<Long, Object>>() {}));
        SingleOutputStreamOperator<Tuple2<Long, Object>> updaterOutput =
            iteration.flatMap(new ModelUpdater(this.dataBridge, params, this.modelSchemeStr, dataSchemaStr));

        // Local gradients (f0 >= 0): sum one contribution per subtask.
        DataStream<Tuple2<Long, Object>> aggregatedGradients = updaterOutput
            .filter(new FilterFunction<Tuple2<Long, Object>>() {
                private static final long serialVersionUID = -5436758453355074895L;

                @Override
                public boolean filter(Tuple2<Long, Object> value) {
                    return value.f0 >= 0;
                }
            })
            .keyBy(0)
            .countWindowAll(parallelism)
            .reduce(new ReduceGradient())
            .map(new MapFunction<Tuple2<Long, Object>, Tuple2<Long, Object>>() {
                @Override
                public Tuple2<Long, Object> map(Tuple2<Long, Object> value) {
                    return value;
                }
            })
            .setParallelism(parallelism);
        DataStream<Tuple2<Long, Object>> feedback =
            (modelFileStream == null ? aggregatedGradients : aggregatedGradients.union(modelFileStream)).broadcast();

        // Model-stream rows (f0 < 0) form the operator output.
        DataStream<Row> modelStream = updaterOutput
            .filter(new FilterFunction<Tuple2<Long, Object>>() {
                private static final long serialVersionUID = 4204787383191799107L;

                @Override
                public boolean filter(Tuple2<Long, Object> value) {
                    return value.f0 < 0;
                }
            })
            .map(new MapFunction<Tuple2<Long, Object>, Row>() {
                @Override
                public Row map(Tuple2<Long, Object> value) {
                    return (Row) value.f1;
                }
            });
        iteration.closeWith(feedback.union(ticks));

        // Output schema: model-stream timestamp and count columns followed by the pipeline-model columns.
        TableSchema modelSchema = TableUtil.schemaStr2Schema(this.modelSchemeStr);
        String[] modelNames = modelSchema.getFieldNames();
        TypeInformation<?>[] modelTypes = modelSchema.getFieldTypes();
        String[] names = new String[modelNames.length + 2];
        TypeInformation<?>[] types = new TypeInformation<?>[modelNames.length + 2];
        names[0] = ModelStreamUtils.MODEL_STREAM_TIMESTAMP_COLUMN_NAME;
        names[1] = ModelStreamUtils.MODEL_STREAM_COUNT_COLUMN_NAME;
        types[0] = ModelStreamUtils.MODEL_STREAM_TIMESTAMP_COLUMN_TYPE;
        types[1] = ModelStreamUtils.MODEL_STREAM_COUNT_COLUMN_TYPE;
        for (int i = 0; i < modelNames.length; i++) {
            names[i + 2] = modelNames[i];
            types[i + 2] = modelTypes[i];
        }
        setOutput(modelStream, names, types);
        return this;
    }
}
