package com.alibaba.alink.operator.stream.onlinelearning;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.directreader.DataBridge;
import com.alibaba.alink.common.io.directreader.DirectReader;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.fm.FmModelData;
import com.alibaba.alink.operator.common.fm.FmModelDataConverter;
import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils;
import com.alibaba.alink.operator.common.optim.FmOptimizer;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;
import com.alibaba.alink.operator.stream.source.ModelStreamFileSourceStreamOp;
import com.alibaba.alink.params.onlinelearning.OnlineFmTrainParams;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.IterativeStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Online factorization-machine (FM) trainer for binary classification.
 *
 * <p>The operator is seeded with an initial FM model produced by a {@link BatchOperator} and refines
 * it with FTRL-style updates computed from mini-batches of its single input stream. Training runs
 * inside a Flink iteration: each subtask computes a sparse gradient, gradients are aggregated and
 * averaged, and the result is broadcast back to every subtask to update its model copy.
 *
 * <p>The output is a model stream. On every timer tick (one every {@code timeInterval} seconds) the
 * current model is emitted as rows prefixed with the model-stream timestamp and row-count columns.
 * Optionally, a model stream read from files ({@link ModelStreamFileSourceStreamOp}) is fed back in
 * to replace ("rebase") the model being trained.
 */
@NameCn("在线FM训练")
@Internal
@NameEn("Online Factorization Machine training")
public final class OnlineFmTrainStreamOp extends StreamOperator<OnlineFmTrainStreamOp>
        implements OnlineFmTrainParams<OnlineFmTrainStreamOp> {
    private static final long serialVersionUID = -1717242899554835631L;

    /** Handle used by the task managers to read the initial model rows. */
    DataBridge dataBridge;
    /** Schema string of the initial model; also used to parse rebase models from the model stream. */
    private final String modelSchemeStr;

    /**
     * Body of the training iteration.
     *
     * <p>Input 1 receives mini-batches of training rows. Input 2 (the feedback edge) receives one of:
     * <ul>
     *   <li>a {@code Map<Integer, double[]>}: an averaged sparse gradient to apply to the model;</li>
     *   <li>a {@link Row}: one row of a model from the model stream, used to rebase the model;</li>
     *   <li>a {@link Long}: a timer tick, which triggers emitting the current model.</li>
     * </ul>
     *
     * <p>Output records are either {@code (version >= 0, sparse gradient)} or
     * {@code (-timestamp, Tuple3(rowCount, modelId, json))} model rows. The sign of {@code f0} is what
     * the downstream filters use to tell them apart.
     *
     * <p>A new gradient is only computed once the previous one has been applied
     * ({@link #isUpdatedModel}); batches arriving in between are buffered.
     */
    public static class ModelUpdater
            extends RichCoFlatMapFunction<Row[], Tuple2<Long, Object>, Tuple2<Long, Object>>
            implements CheckpointedFunction {
        private static final long serialVersionUID = 7338858137860097282L;

        /** Key under which the intercept (bias) gradient is stored in a sparse gradient map. */
        private static final int BIAS_KEY = -1;

        private final DataBridge dataBridge;
        private final int vectorTrainIdx;
        private final int labelIdx;
        private final int[] featureIdx;
        private final int batchSize;
        private final double l1;
        private final double l2;
        private final double alpha;
        private final double beta;
        private final String modelSchemaStr;

        /** {@code {withIntercept ? 1 : 0, withLinearItem ? 1 : 0, numFactor}}. */
        final int[] dim = new int[3];
        /** Regularization weights {@code {lambda0 (bias), lambda1 (linear), lambda2 (factors)}}. */
        final double[] regular = new double[3];

        private transient FmModelData modelData;
        private transient BaseFmTrainBatchOp.LossFunction lossFunc;
        /** FTRL "n" accumulators (sum of squared gradients), same layout as the model. */
        private transient BaseFmTrainBatchOp.FmDataFormat nParam;
        /** FTRL "z" accumulators, same layout as the model. */
        private transient BaseFmTrainBatchOp.FmDataFormat zParam;
        private transient ListState<Tuple4<BaseFmTrainBatchOp.FmDataFormat, BaseFmTrainBatchOp.FmDataFormat,
                BaseFmTrainBatchOp.FmDataFormat, Object[]>> modelState;
        /** Training batches received while waiting for the previous gradient to be applied. */
        private transient List<Row[]> localBatchDataBuffer;
        /** Aggregated gradients received while no locally computed gradient was outstanding. */
        private transient List<Map<Integer, double[]>> gradientBuffer;
        /** Partially received rebase models from the model stream, keyed by model timestamp. */
        private transient Map<Timestamp, List<Row>> buffers;

        private int maxNumBatches;
        private long gradVersion = 0;
        private boolean isUpdatedModel = true;
        private Object[] labelValues = null;

        public ModelUpdater(DataBridge dataBridge, Params params, int vectorTrainIdx, int[] featureIdx,
                            int labelIdx, String modelSchemaStr) {
            this.dataBridge = dataBridge;
            this.batchSize = params.get(OnlineFmTrainParams.MINI_BATCH_SIZE);
            this.vectorTrainIdx = vectorTrainIdx;
            this.labelIdx = labelIdx;
            this.featureIdx = featureIdx;
            this.regular[0] = params.get(OnlineFmTrainParams.LAMBDA_0);
            this.regular[1] = params.get(OnlineFmTrainParams.LAMBDA_1);
            this.regular[2] = params.get(OnlineFmTrainParams.LAMBDA_2);
            this.dim[0] = params.get(OnlineFmTrainParams.WITH_INTERCEPT) ? 1 : 0;
            this.dim[1] = params.get(OnlineFmTrainParams.WITH_LINEAR_ITEM) ? 1 : 0;
            this.dim[2] = params.get(OnlineFmTrainParams.NUM_FACTOR);
            this.l1 = params.get(OnlineFmTrainParams.L_1);
            this.l2 = params.get(OnlineFmTrainParams.L_2);
            this.alpha = params.get(OnlineFmTrainParams.ALPHA);
            this.beta = params.get(OnlineFmTrainParams.BETA);
            this.modelSchemaStr = modelSchemaStr;
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            this.maxNumBatches =
                Math.max(1, (100000 * getRuntimeContext().getNumberOfParallelSubtasks()) / this.batchSize);
            this.lossFunc = new BaseFmTrainBatchOp.LogitLoss();
            // Initialized eagerly: a feedback gradient may reach flatMap2 before any batch reaches flatMap1.
            this.localBatchDataBuffer = new ArrayList<>();
            this.gradientBuffer = new ArrayList<>();
            this.buffers = new HashMap<>();
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            modelState.clear();
            modelState.add(Tuple4.of(modelData.fmModel, nParam, zParam, labelValues));
        }

        /**
         * Registers the model list state and initializes the model.
         *
         * <p>The initial model is always loaded. It carries metadata (column names, dim, vector size)
         * that is not checkpointed. It also provides defaults when no state entry is handed to this
         * subtask on restore. If a state entry is restored, it overrides the weights, the FTRL
         * accumulators and the label values.
         */
        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            ListStateDescriptor<Tuple4<BaseFmTrainBatchOp.FmDataFormat, BaseFmTrainBatchOp.FmDataFormat,
                BaseFmTrainBatchOp.FmDataFormat, Object[]>> descriptor = new ListStateDescriptor<>(
                "StreamingOnlineModelState",
                TypeInformation.of(new TypeHint<Tuple4<BaseFmTrainBatchOp.FmDataFormat,
                    BaseFmTrainBatchOp.FmDataFormat, BaseFmTrainBatchOp.FmDataFormat, Object[]>>() {}));
            modelState = context.getOperatorStateStore().getListState(descriptor);

            modelData = new FmModelDataConverter().load(DirectReader.directRead(dataBridge));
            resetFtrlAccumulators();
            labelValues = modelData.labelValues;

            if (context.isRestored()) {
                for (Tuple4<BaseFmTrainBatchOp.FmDataFormat, BaseFmTrainBatchOp.FmDataFormat,
                    BaseFmTrainBatchOp.FmDataFormat, Object[]> state : modelState.get()) {
                    modelData.fmModel = state.f0;
                    nParam = state.f1;
                    zParam = state.f2;
                    labelValues = state.f3;
                }
            }
        }

        /**
         * Buffers a training batch. If the previous gradient has already been applied, this also
         * computes a sparse gradient over all buffered batches and emits it.
         */
        @Override
        public void flatMap1(Row[] batch, Collector<Tuple2<Long, Object>> out) throws Exception {
            localBatchDataBuffer.add(batch);
            if (!isUpdatedModel) {
                // Still waiting for the last gradient to come back; bound the backlog by dropping the
                // oldest batches. At least one batch is dropped so the bound also holds when
                // maxNumBatches == 1.
                if (localBatchDataBuffer.size() > maxNumBatches) {
                    int numToDrop = Math.max(1, Math.min(maxNumBatches - 1, (maxNumBatches * 9) / 10));
                    localBatchDataBuffer.subList(0, numToDrop).clear();
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println("remove batches happened. ");
                    }
                }
                return;
            }

            // Fresh map per emission, so the emitted record never aliases state mutated later.
            Map<Integer, Tuple2<Double, double[]>> sparseGradient = new HashMap<>();
            for (Row[] bufferedBatch : localBatchDataBuffer) {
                for (Row row : bufferedBatch) {
                    accumulateGradient(row, sparseGradient);
                }
            }
            localBatchDataBuffer.clear();
            isUpdatedModel = false;
            long version = gradVersion++;
            out.collect(Tuple2.<Long, Object>of(version, sparseGradient));
        }

        /**
         * Adds the gradient of one sample to {@code sparseGradient}.
         *
         * <p>Each map entry is (sample count, summed gradient). Key {@link #BIAS_KEY} holds the
         * intercept; other keys are feature indices with layout {@code [factors..., linear]}.
         */
        private void accumulateGradient(Row row, Map<Integer, Tuple2<Double, double[]>> sparseGradient) {
            Vector vector = extractFeatures(row);
            // Positive class is labelValues[0]; loss is the logistic loss created in open().
            double label = row.getField(labelIdx).equals(labelValues[0]) ? 1.0 : 0.0;
            Tuple2<Double, double[]> prediction = FmOptimizer.calcY(vector, modelData.fmModel, dim);
            double dldy = lossFunc.dldy(label, prediction.f0);

            int[] indices;
            double[] values;
            if (vector instanceof SparseVector) {
                indices = ((SparseVector) vector).getIndices();
                values = ((SparseVector) vector).getValues();
            } else {
                indices = new int[vector.size()];
                for (int i = 0; i < indices.length; i++) {
                    indices[i] = i;
                }
                values = ((DenseVector) vector).getData();
            }

            if (dim[0] > 0) {
                double biasGrad = dldy + regular[0] * modelData.fmModel.bias;
                double[] slot = gradientSlot(sparseGradient, BIAS_KEY, 1);
                slot[0] = slot[0] + biasGrad;
            }

            double[][] factors = modelData.fmModel.factors;
            double[] vx = prediction.f1;
            for (int i = 0; i < indices.length; i++) {
                int featureId = indices[i];
                double x = values[i];
                double[] w = factors[featureId];
                double[] slot = gradientSlot(sparseGradient, featureId, dim[2] + dim[1]);
                if (dim[1] > 0) {
                    int linear = dim[2];
                    slot[linear] = slot[linear] + (dldy * x) + (regular[1] * w[linear]);
                }
                for (int k = 0; k < dim[2]; k++) {
                    slot[k] = slot[k] + (dldy * x * (vx[k] - x * w[k])) + (regular[2] * w[k]);
                }
            }
        }

        /** Returns the feature vector of a row, from the vector column or from the feature columns. */
        private Vector extractFeatures(Row row) {
            if (vectorTrainIdx != -1) {
                return VectorUtil.getVector(row.getField(vectorTrainIdx));
            }
            DenseVector dense = new DenseVector(featureIdx.length);
            for (int i = 0; i < featureIdx.length; i++) {
                dense.set(i, Double.parseDouble(row.getField(featureIdx[i]).toString()));
            }
            return dense;
        }

        /**
         * Returns the gradient array stored under {@code key}, creating a zeroed array of length
         * {@code width} if absent, and increments the entry's sample count.
         */
        private static double[] gradientSlot(Map<Integer, Tuple2<Double, double[]>> grad, int key, int width) {
            Tuple2<Double, double[]> entry = grad.get(key);
            if (entry == null) {
                entry = Tuple2.of(0.0, new double[width]);
                grad.put(key, entry);
            }
            entry.f0 = entry.f0 + 1.0;
            return entry.f1;
        }

        /** Dispatches a feedback record: apply a gradient, rebase from a model row, or emit the model. */
        @Override
        @SuppressWarnings("unchecked")
        public void flatMap2(Tuple2<Long, Object> value, Collector<Tuple2<Long, Object>> out) {
            Object payload = value.f1;
            if (payload instanceof Map) {
                gradientBuffer.add((Map<Integer, double[]>) payload);
                if (!isUpdatedModel) {
                    updateModel(gradientBuffer.remove(0));
                    isUpdatedModel = true;
                }
            } else if (payload instanceof Row) {
                rebaseModel((Row) payload);
            } else if (payload instanceof Long) {
                emitModel(out);
            } else {
                throw new AkUnclassifiedErrorException("feedback data type err, must be a Map, Row or Long.");
            }
        }

        /**
         * Collects one model-stream row. Once all rows of that model have arrived, replaces the
         * current weights with it and resets the FTRL accumulators.
         */
        private void rebaseModel(Row row) {
            Timestamp timestamp = (Timestamp) row.getField(0);
            long expectedCount = (Long) row.getField(1);
            List<Row> pending = buffers.computeIfAbsent(timestamp, k -> new ArrayList<>());
            pending.add(ModelStreamUtils.genRowWithoutIdentifier(row, 0, 1));
            if (pending.size() != (int) expectedCount) {
                return;
            }
            // Remove before loading so that a model that fails to load does not leave a stale buffer.
            buffers.remove(timestamp);
            try {
                modelData.fmModel = new FmModelDataConverter(
                    FmModelDataConverter.extractLabelType(TableUtil.schemaStr2Schema(modelSchemaStr)))
                    .load(pending).fmModel;
                resetFtrlAccumulators();
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("rebase fm model.");
                }
            } catch (Exception e) {
                // Best effort: keep training the current model, but surface the cause.
                System.err.println("test Model stream updating failed. Please check your model stream. " + e);
            }
        }

        /**
         * Emits the current model as model-stream records: one meta row (id null), one row per factor
         * vector (id = feature index), and one bias row (id -1).
         */
        private void emitModel(Collector<Tuple2<Long, Object>> out) {
            Params meta = new Params()
                .set(ModelParamName.VECTOR_COL_NAME, modelData.vectorColName)
                .set(ModelParamName.LABEL_COL_NAME, modelData.labelColName)
                .set(ModelParamName.TASK, BaseFmTrainBatchOp.Task.BINARY_CLASSIFICATION)
                .set(ModelParamName.VECTOR_SIZE, modelData.vectorSize)
                .set(ModelParamName.FEATURE_COL_NAMES, modelData.featureColNames)
                .set(ModelParamName.LABEL_VALUES, labelValues)
                .set(ModelParamName.DIM, modelData.dim);
            long timestamp = System.currentTimeMillis();
            long rowCount = modelData.vectorSize + 2;
            out.collect(modelRecord(timestamp, rowCount, null, meta.toJson()));
            double[][] factors = modelData.fmModel.factors;
            for (int i = 0; i < factors.length; i++) {
                out.collect(modelRecord(timestamp, rowCount, (long) i, JsonConverter.toJson(factors[i])));
            }
            out.collect(modelRecord(timestamp, rowCount, -1L,
                JsonConverter.toJson(new double[] {modelData.fmModel.bias})));
        }

        /** Builds one output model record; the negated timestamp marks it as model output downstream. */
        private static Tuple2<Long, Object> modelRecord(long timestamp, long rowCount, Long modelId, String json) {
            return Tuple2.<Long, Object>of(-timestamp, Tuple3.of(rowCount, modelId, json));
        }

        /** Applies an averaged sparse gradient to the model with FTRL updates. */
        private void updateModel(Map<Integer, double[]> gradient) {
            for (Map.Entry<Integer, double[]> entry : gradient.entrySet()) {
                int key = entry.getKey();
                double[] grad = entry.getValue();
                if (key == BIAS_KEY) {
                    assert grad.length == 1;
                    updateBias(grad[0]);
                } else {
                    if (dim[1] > 0) {
                        updateModelVal(key, dim[2], grad[dim[2]]);
                    }
                    for (int k = 0; k < dim[2]; k++) {
                        updateModelVal(key, k, grad[k]);
                    }
                }
            }
        }

        /** FTRL update of the intercept with gradient {@code g}. */
        private void updateBias(double g) {
            double sigma = (Math.sqrt(nParam.bias + g * g) - Math.sqrt(nParam.bias)) / alpha;
            zParam.bias += g - sigma * modelData.fmModel.bias;
            nParam.bias += g * g;
            modelData.fmModel.bias = ftrlWeight(zParam.bias, nParam.bias);
        }

        /** FTRL update of element {@code k} of feature {@code featureId}'s weights with gradient {@code g}. */
        private void updateModelVal(int featureId, int k, double g) {
            double[] n = nParam.factors[featureId];
            double[] z = zParam.factors[featureId];
            double[] w = modelData.fmModel.factors[featureId];
            double sigma = (Math.sqrt(n[k] + g * g) - Math.sqrt(n[k])) / alpha;
            z[k] = z[k] + (g - sigma * w[k]);
            n[k] = n[k] + g * g;
            w[k] = ftrlWeight(z[k], n[k]);
        }

        /** Closed-form FTRL-Proximal weight for accumulators {@code z} and {@code n}. */
        private double ftrlWeight(double z, double n) {
            if (Math.abs(z) <= l1) {
                return 0.0;
            }
            return ((z < 0.0 ? -1 : 1) * l1 - z) / ((beta + Math.sqrt(n)) / alpha + l2);
        }

        /** Re-creates zeroed n/z accumulators shaped like the current model. */
        private void resetFtrlAccumulators() {
            double[][] factors = modelData.fmModel.factors;
            nParam = new BaseFmTrainBatchOp.FmDataFormat(factors.length, factors[0].length, modelData.dim, 0.0);
            zParam = new BaseFmTrainBatchOp.FmDataFormat(factors.length, factors[0].length, modelData.dim, 0.0);
        }
    }

    /**
     * Converts a model record {@code (-timestamp, Tuple3(rowCount, modelId, json))} into an output row
     * {@code (timestamp, rowCount, modelId, json, null)}.
     */
    public static class WriteModel extends RichFlatMapFunction<Tuple2<Long, Object>, Row> {
        private static final long serialVersionUID = 3487644568763785149L;

        @Override
        public void flatMap(Tuple2<Long, Object> value, Collector<Row> out) throws Exception {
            Timestamp timestamp = new Timestamp(-value.f0);
            Tuple3<?, ?, ?> record = (Tuple3<?, ?, ?>) value.f1;
            out.collect(Row.of(timestamp, record.f0, record.f1, record.f2, null));
        }
    }

    public OnlineFmTrainStreamOp(BatchOperator<?> batchOperator) {
        this(batchOperator, new Params());
    }

    /**
     * @param batchOperator operator producing the initial FM model; must not be null
     * @param params        operator parameters
     * @throws AkIllegalModelException if {@code batchOperator} is null
     */
    public OnlineFmTrainStreamOp(BatchOperator<?> batchOperator, Params params) {
        super(params);
        if (batchOperator == null) {
            throw new AkIllegalModelException("Online algo: initial model is null. Please set the initial model.");
        }
        this.dataBridge = DirectReader.collect(batchOperator);
        this.modelSchemeStr = TableUtil.schema2SchemaStr(batchOperator.getSchema());
    }

    /**
     * Builds the training iteration over {@code inputs[0]} and sets the model stream as output.
     *
     * <p>Topology: training rows are batched and fed to {@link ModelUpdater}. Its gradients
     * ({@code f0 >= 0}) are aggregated, averaged and broadcast back, together with an optional
     * file-based model stream and a timer stream. Its model records ({@code f0 < 0}) become the
     * output through {@link WriteModel}.
     */
    @Override
    public OnlineFmTrainStreamOp linkFrom(StreamOperator<?>... inputs) {
        StreamExecutionEnvironment streamEnv =
            MLEnvironmentFactory.get(getMLEnvironmentId()).getStreamExecutionEnvironment();
        int parallelism = streamEnv.getParallelism();
        streamEnv.getCheckpointConfig().setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);
        streamEnv.getCheckpointConfig().setForceCheckpointing(true);
        checkOpSize(1, inputs);

        Params params = getParams();
        String vectorCol = getVectorCol();
        String[] featureCols = getFeatureCols();
        String[] inputColNames = inputs[0].getColNames();
        int vectorTrainIdx = vectorCol != null
            ? TableUtil.findColIndexWithAssertAndHint(inputColNames, vectorCol) : -1;
        int labelIdx = TableUtil.findColIndexWithAssertAndHint(inputColNames, getLabelCol());
        int[] featureIdx = null;
        if (vectorTrainIdx == -1) {
            featureIdx = new int[featureCols.length];
            for (int i = 0; i < featureCols.length; i++) {
                featureIdx[i] = TableUtil.findColIndexWithAssertAndHint(inputColNames, featureCols[i]);
            }
        }
        TypeInformation<?> labelType = inputs[0].getColTypes()[labelIdx];

        // Timer: emits a tick every timeInterval seconds; ModelUpdater answers a tick by emitting the model.
        final int timeInterval = getTimeInterval();
        DataStream<Tuple2<Long, Object>> timer = streamEnv.fromElements(Row.of(1))
            .flatMap(new FlatMapFunction<Row, Tuple2<Long, Object>>() {
                @Override
                public void flatMap(Row value, Collector<Tuple2<Long, Object>> out) throws Exception {
                    for (long tick = 0; tick < Long.MAX_VALUE; tick++) {
                        // long arithmetic: an int product overflows for large intervals.
                        Thread.sleep(timeInterval * 1000L);
                        out.collect(Tuple2.<Long, Object>of(0L, 0L));
                    }
                }
            });

        DataStream<Row[]> batchData = inputs[0].getDataStream().rebalance()
            .flatMap(new FtrlTrainStreamOp.PrepareBatchSample(Math.max(1, getMiniBatchSize() / parallelism)));

        DataStream<Tuple2<Long, Object>> modelStream = null;
        if (ModelStreamUtils.useModelStreamFile(params)) {
            StreamOperator<?> modelSource = (StreamOperator<?>) new ModelStreamFileSourceStreamOp()
                .setFilePath(getModelStreamFilePath())
                .setScanInterval(getModelStreamScanInterval())
                .setStartTime(getModelStreamStartTime())
                .setSchemaStr(this.modelSchemeStr)
                .setMLEnvironmentId(inputs[0].getMLEnvironmentId());
            modelStream = modelSource.getDataStream().map(new MapFunction<Row, Tuple2<Long, Object>>() {
                @Override
                public Tuple2<Long, Object> map(Row row) {
                    return Tuple2.<Long, Object>of(-1L, row);
                }
            });
        }

        IterativeStream.ConnectedIterativeStreams<Row[], Tuple2<Long, Object>> iteration = batchData
            .iterate(Long.MAX_VALUE)
            .withFeedbackType(TypeInformation.of(new TypeHint<Tuple2<Long, Object>>() {}));
        SingleOutputStreamOperator<Tuple2<Long, Object>> updaterOutput = iteration.flatMap(
            new ModelUpdater(this.dataBridge, params, vectorTrainIdx, featureIdx, labelIdx, this.modelSchemeStr));

        // NOTE(review): countWindowAll is a non-keyed window, so the preceding keyBy does not group by
        // gradient version; any `parallelism` consecutive gradients are reduced together. Confirm intended.
        DataStream<Tuple2<Long, Object>> gradients = updaterOutput
            .filter(new FilterFunction<Tuple2<Long, Object>>() {
                private static final long serialVersionUID = -5436758453355074895L;

                @Override
                public boolean filter(Tuple2<Long, Object> value) {
                    return value.f0 >= 0;
                }
            })
            .keyBy(0)
            .countWindowAll(parallelism)
            .reduce(new ReduceFunction<Tuple2<Long, Object>>() {
                @Override
                @SuppressWarnings("unchecked")
                public Tuple2<Long, Object> reduce(Tuple2<Long, Object> acc, Tuple2<Long, Object> other) {
                    // Merge (count, summed gradient) entries of `other` into `acc` in place.
                    Map<Integer, Tuple2<Double, double[]>> accGrad = (Map<Integer, Tuple2<Double, double[]>>) acc.f1;
                    Map<Integer, Tuple2<Double, double[]>> otherGrad =
                        (Map<Integer, Tuple2<Double, double[]>>) other.f1;
                    for (Map.Entry<Integer, Tuple2<Double, double[]>> entry : otherGrad.entrySet()) {
                        Tuple2<Double, double[]> mine = accGrad.get(entry.getKey());
                        if (mine == null) {
                            accGrad.put(entry.getKey(), entry.getValue());
                            continue;
                        }
                        mine.f0 = mine.f0 + entry.getValue().f0;
                        double[] dst = mine.f1;
                        double[] src = entry.getValue().f1;
                        for (int i = 0; i < dst.length; i++) {
                            dst[i] += src[i];
                        }
                    }
                    return acc;
                }
            })
            .map(new MapFunction<Tuple2<Long, Object>, Tuple2<Long, Object>>() {
                @Override
                @SuppressWarnings("unchecked")
                public Tuple2<Long, Object> map(Tuple2<Long, Object> value) {
                    // Divide each summed gradient by its sample count.
                    Map<Integer, Tuple2<Double, double[]>> summed = (Map<Integer, Tuple2<Double, double[]>>) value.f1;
                    Map<Integer, double[]> averaged = new HashMap<>();
                    for (Map.Entry<Integer, Tuple2<Double, double[]>> entry : summed.entrySet()) {
                        double count = entry.getValue().f0;
                        double[] grad = entry.getValue().f1;
                        for (int i = 0; i < grad.length; i++) {
                            grad[i] = grad[i] / count;
                        }
                        averaged.put(entry.getKey(), grad);
                    }
                    return Tuple2.<Long, Object>of(value.f0, averaged);
                }
            })
            .setParallelism(parallelism);

        DataStream<Tuple2<Long, Object>> feedback =
            (modelStream == null ? gradients : gradients.union(modelStream)).broadcast();

        DataStream<Row> modelRows = updaterOutput
            .filter(new FilterFunction<Tuple2<Long, Object>>() {
                private static final long serialVersionUID = 4204787383191799107L;

                @Override
                public boolean filter(Tuple2<Long, Object> value) {
                    return value.f0 < 0;
                }
            })
            .flatMap(new WriteModel());

        iteration.closeWith(feedback.union(timer));

        TableSchema modelSchema = new FmModelDataConverter(labelType).getModelSchema();
        String[] modelNames = modelSchema.getFieldNames();
        TypeInformation<?>[] modelTypes = modelSchema.getFieldTypes();
        String[] outNames = new String[modelNames.length + 2];
        TypeInformation<?>[] outTypes = new TypeInformation<?>[modelNames.length + 2];
        outNames[0] = ModelStreamUtils.MODEL_STREAM_TIMESTAMP_COLUMN_NAME;
        outNames[1] = ModelStreamUtils.MODEL_STREAM_COUNT_COLUMN_NAME;
        outTypes[0] = ModelStreamUtils.MODEL_STREAM_TIMESTAMP_COLUMN_TYPE;
        outTypes[1] = ModelStreamUtils.MODEL_STREAM_COUNT_COLUMN_TYPE;
        System.arraycopy(modelNames, 0, outNames, 2, modelNames.length);
        System.arraycopy(modelTypes, 0, outTypes, 2, modelTypes.length);
        setOutput(modelRows, outNames, outTypes);
        return this;
    }
}
