package com.alibaba.alink.operator.stream.onlinelearning;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.io.directreader.DataBridge;
import com.alibaba.alink.common.io.directreader.DirectReader;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.linear.LabelTypeEnum;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.source.ModelStreamFileSourceStreamOp;
import com.alibaba.alink.params.onlinelearning.FtrlTrainParams;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.IterativeStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@NameCn("Ftrl在线训练")
@NameEn("Follow the regularized leader model training")
/* Streaming FTRL trainer: refines an initial linear (logistic regression) model on a data stream and periodically emits the current model as model-stream rows. */
public class FtrlTrainStreamOp extends StreamOperator<FtrlTrainStreamOp> implements FtrlTrainParams<FtrlTrainStreamOp> {
    private static final long serialVersionUID = 3688413917992858013L;
    private final DataBridge dataBridge;
    private final String modelSchemeStr;

    /* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/FtrlTrainStreamOp$ModelUpdater.class */
    public static class ModelUpdater extends RichCoFlatMapFunction<Row[], Tuple2<Long, Object>, Tuple2<Long, Object>> implements CheckpointedFunction {
        private static final long serialVersionUID = 7338858137860097282L;
        private final DataBridge dataBridge;
        private transient DenseVector coefficientVector;
        private final double alpha;
        private final double beta;
        private final double l1;
        private final double l2;
        private double[] nParam;
        private double[] zParam;
        private final boolean hasInterceptItem;
        private final int vectorTrainIdx;
        private final int labelIdx;
        private final int[] featureIdx;
        private transient ListState<Tuple4<double[], double[], double[], Object[]>> modelState;
        private List<Row[]> localBatchDataBuffer;
        private List<SparseVector> gradientBuffer;
        private final int batchSize;
        private int maxNumBatches;
        private long gradientVersion = 0;
        private boolean isUpdatedModel = true;
        private final Map<Integer, double[]> sparseGradient = new HashMap();
        private Object[] labelValues = null;

        public ModelUpdater(DataBridge dataBridge, Params params, int[] iArr, int i, int i2) {
            this.dataBridge = dataBridge;
            this.alpha = ((Double) params.get(FtrlTrainParams.ALPHA)).doubleValue();
            this.beta = ((Double) params.get(FtrlTrainParams.BETA)).doubleValue();
            this.l1 = ((Double) params.get(FtrlTrainParams.L_1)).doubleValue();
            this.l2 = ((Double) params.get(FtrlTrainParams.L_2)).doubleValue();
            this.batchSize = ((Integer) params.get(FtrlTrainParams.MINI_BATCH_SIZE)).intValue();
            this.hasInterceptItem = ((Boolean) params.get(FtrlTrainParams.WITH_INTERCEPT)).booleanValue();
            this.vectorTrainIdx = i;
            this.labelIdx = i2;
            this.featureIdx = iArr;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.maxNumBatches = Math.max(1, (100000 * getRuntimeContext().getNumberOfParallelSubtasks()) / this.batchSize);
        }

        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
            Tuple4 of = Tuple4.of(this.coefficientVector.getData(), this.nParam, this.zParam, this.labelValues);
            this.modelState.clear();
            this.modelState.add(of);
        }

        public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
            this.modelState = functionInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("StreamingOnlineModelState", TypeInformation.of(new TypeHint<Tuple4<double[], double[], double[], Object[]>>() { // from class: com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp.ModelUpdater.1
            })));
            if (!functionInitializationContext.isRestored()) {
                LinearModelData load = new LinearModelDataConverter().load(DirectReader.directRead(this.dataBridge));
                this.coefficientVector = load.coefVector;
                this.nParam = new double[this.coefficientVector.size()];
                this.zParam = new double[this.coefficientVector.size()];
                this.labelValues = load.labelValues;
                return;
            }
            for (Tuple4 tuple4 : (Iterable) this.modelState.get()) {
                this.coefficientVector = new DenseVector((double[]) tuple4.f0);
                this.nParam = (double[]) tuple4.f1;
                this.zParam = (double[]) tuple4.f2;
                this.labelValues = (Object[]) tuple4.f3;
            }
        }

        public void flatMap1(Row[] rowArr, Collector<Tuple2<Long, Object>> collector) throws Exception {
            Vector vector;
            if (this.localBatchDataBuffer == null) {
                this.localBatchDataBuffer = new ArrayList();
                this.gradientBuffer = new ArrayList();
            }
            this.localBatchDataBuffer.add(rowArr);
            if (!this.isUpdatedModel) {
                if (this.localBatchDataBuffer.size() > this.maxNumBatches) {
                    this.localBatchDataBuffer.subList(0, Math.min(this.maxNumBatches - 1, (this.maxNumBatches * 9) / 10)).clear();
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println("remove batches happened. ");
                        return;
                    }
                    return;
                }
                return;
            }
            this.sparseGradient.clear();
            for (Row[] rowArr2 : this.localBatchDataBuffer) {
                for (Row row : rowArr2) {
                    if (this.vectorTrainIdx == -1) {
                        vector = new DenseVector(this.featureIdx.length);
                        for (int i = 0; i < this.featureIdx.length; i++) {
                            vector.set(i, Double.parseDouble(row.getField(this.featureIdx[i]).toString()));
                        }
                    } else {
                        vector = VectorUtil.getVector(row.getField(this.vectorTrainIdx));
                    }
                    if (this.hasInterceptItem) {
                        vector = vector.prefix(1.0d);
                    }
                    double d = row.getField(this.labelIdx).equals(this.labelValues[0]) ? 1.0d : Criteria.INVALID_GAIN;
                    double exp = 1.0d / (1.0d + Math.exp(-this.coefficientVector.dot(vector)));
                    if (vector instanceof DenseVector) {
                        DenseVector denseVector = (DenseVector) vector;
                        for (int i2 = 0; i2 < this.coefficientVector.size(); i2++) {
                            if (this.sparseGradient.containsKey(Integer.valueOf(i2))) {
                                double[] dArr = this.sparseGradient.get(Integer.valueOf(i2));
                                dArr[0] = dArr[0] + ((exp - d) * denseVector.getData()[i2]);
                                double[] dArr2 = this.sparseGradient.get(Integer.valueOf(i2));
                                dArr2[1] = dArr2[1] + 1.0d;
                            } else {
                                this.sparseGradient.put(Integer.valueOf(i2), new double[]{(exp - d) * denseVector.getData()[i2], 1.0d});
                            }
                        }
                    } else {
                        SparseVector sparseVector = (SparseVector) vector;
                        for (int i3 = 0; i3 < sparseVector.getIndices().length; i3++) {
                            int i4 = sparseVector.getIndices()[i3];
                            if (this.sparseGradient.containsKey(Integer.valueOf(i4))) {
                                double[] dArr3 = this.sparseGradient.get(Integer.valueOf(i4));
                                dArr3[0] = dArr3[0] + ((exp - d) * sparseVector.getValues()[i3]);
                                double[] dArr4 = this.sparseGradient.get(Integer.valueOf(i4));
                                dArr4[1] = dArr4[1] + 1.0d;
                            } else {
                                this.sparseGradient.put(Integer.valueOf(i4), new double[]{(exp - d) * sparseVector.getValues()[i3], 1.0d});
                            }
                        }
                    }
                }
            }
            this.localBatchDataBuffer.clear();
            this.isUpdatedModel = false;
            long j = this.gradientVersion;
            this.gradientVersion = j + 1;
            collector.collect(Tuple2.of(Long.valueOf(j), this.sparseGradient));
        }

        public void flatMap2(Tuple2<Long, Object> tuple2, Collector<Tuple2<Long, Object>> collector) {
            if (tuple2.f1 instanceof SparseVector) {
                this.gradientBuffer.add((SparseVector) tuple2.f1);
                if (this.isUpdatedModel) {
                    return;
                }
                updateModel(this.coefficientVector, this.gradientBuffer.remove(0));
                this.isUpdatedModel = true;
                return;
            }
            if (!(tuple2.f1 instanceof DenseVector)) {
                if (!(tuple2.f1 instanceof Long)) {
                    throw new AkUnclassifiedErrorException("feedback data type err, must be a Map, Long, or DenseVector.");
                }
                collector.collect(Tuple2.of(-1L, Tuple2.of(this.labelValues, this.coefficientVector)));
            } else {
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("rebase linear model.");
                }
                this.coefficientVector = (DenseVector) tuple2.f1;
                this.nParam = new double[this.coefficientVector.size()];
                this.zParam = new double[this.coefficientVector.size()];
            }
        }

        private void updateModel(DenseVector denseVector, SparseVector sparseVector) {
            int[] indices = sparseVector.getIndices();
            double[] values = sparseVector.getValues();
            for (int i = 0; i < indices.length; i++) {
                int i2 = indices[i];
                double sqrt = (Math.sqrt(this.nParam[i2] + (values[i] * values[i])) - Math.sqrt(this.nParam[i2])) / this.alpha;
                double[] dArr = this.zParam;
                dArr[i2] = dArr[i2] + (values[i] - (sqrt * denseVector.getData()[i2]));
                double[] dArr2 = this.nParam;
                dArr2[i2] = dArr2[i2] + (values[i] * values[i]);
                if (Math.abs(this.zParam[i2]) <= this.l1) {
                    denseVector.set(i2, Criteria.INVALID_GAIN);
                } else {
                    denseVector.set(i2, (((this.zParam[i2] < Criteria.INVALID_GAIN ? -1 : 1) * this.l1) - this.zParam[i2]) / (((this.beta + Math.sqrt(this.nParam[i2])) / this.alpha) + this.l2));
                }
            }
        }

        public /* bridge */ /* synthetic */ void flatMap2(Object obj, Collector collector) throws Exception {
            flatMap2((Tuple2<Long, Object>) obj, (Collector<Tuple2<Long, Object>>) collector);
        }

        public /* bridge */ /* synthetic */ void flatMap1(Object obj, Collector collector) throws Exception {
            flatMap1((Row[]) obj, (Collector<Tuple2<Long, Object>>) collector);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/FtrlTrainStreamOp$ParseRebaseModel.class */
    public static class ParseRebaseModel implements FlatMapFunction<Row, Tuple2<Long, Object>> {
        private final Map<Timestamp, List<Row>> buffers = new HashMap();
        private final String modelSchemaStr;

        public ParseRebaseModel(String str) {
            this.modelSchemaStr = str;
        }

        public void flatMap(Row row, Collector<Tuple2<Long, Object>> collector) throws Exception {
            Timestamp timestamp = (Timestamp) row.getField(0);
            long longValue = ((Long) row.getField(1)).longValue();
            Row genRowWithoutIdentifier = ModelStreamUtils.genRowWithoutIdentifier(row, 0, 1);
            if (this.buffers.containsKey(timestamp)) {
                this.buffers.get(timestamp).add(genRowWithoutIdentifier);
            } else {
                ArrayList arrayList = new ArrayList(0);
                arrayList.add(genRowWithoutIdentifier);
                this.buffers.put(timestamp, arrayList);
            }
            if (this.buffers.get(timestamp).size() == ((int) longValue)) {
                try {
                    collector.collect(Tuple2.of(-1L, new LinearModelDataConverter(LinearModelDataConverter.extractLabelType(TableUtil.schemaStr2Schema(this.modelSchemaStr))).load(this.buffers.get(timestamp)).coefVector));
                    this.buffers.get(timestamp).clear();
                } catch (Exception e) {
                    System.err.println("Model stream updating failed. Please check your model stream.");
                }
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Tuple2<Long, Object>>) collector);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/FtrlTrainStreamOp$PrepareBatchSample.class */
    public static class PrepareBatchSample extends RichFlatMapFunction<Row, Row[]> {
        private static final long serialVersionUID = 3738888745125082777L;
        private final int batchSize;
        private final Row[] bufferedData;
        private int idx = 0;

        public PrepareBatchSample(int i) {
            this.batchSize = i;
            this.bufferedData = new Row[i];
        }

        public void flatMap(Row row, Collector<Row[]> collector) throws Exception {
            Row[] rowArr = this.bufferedData;
            int i = this.idx;
            this.idx = i + 1;
            rowArr[i] = row;
            if (this.idx == this.batchSize) {
                collector.collect(this.bufferedData);
                this.idx = 0;
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Row[]>) collector);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/stream/onlinelearning/FtrlTrainStreamOp$WriteModel.class */
    public static class WriteModel extends RichFlatMapFunction<Tuple2<Long, Object>, Row> {
        private static final long serialVersionUID = 828728999893377750L;
        private final LabelTypeEnum.StringTypeEnum type;
        private final String vectorColName;
        private final String[] featureCols;
        private final boolean hasInterceptItem;

        public WriteModel(TypeInformation<?> typeInformation, String str, String[] strArr, boolean z) {
            this.type = LabelTypeEnum.StringTypeEnum.valueOf(typeInformation.toString().toUpperCase());
            this.vectorColName = str;
            this.featureCols = strArr;
            this.hasInterceptItem = z;
        }

        public void flatMap(Tuple2<Long, Object> tuple2, Collector<Row> collector) throws Exception {
            Tuple2 tuple22 = (Tuple2) tuple2.f1;
            LinearModelData linearModelData = new LinearModelData();
            linearModelData.coefVector = (DenseVector) tuple22.f1;
            linearModelData.hasInterceptItem = this.hasInterceptItem;
            linearModelData.vectorColName = this.vectorColName;
            linearModelData.modelName = "Logistic Regression";
            linearModelData.featureNames = this.featureCols;
            linearModelData.labelValues = (Object[]) tuple22.f0;
            linearModelData.vectorSize = this.hasInterceptItem ? linearModelData.coefVector.size() - 1 : linearModelData.coefVector.size();
            linearModelData.linearModelType = LinearModelType.LR;
            RowCollector rowCollector = new RowCollector();
            new LinearModelDataConverter().save(linearModelData, rowCollector);
            List<Row> rows = rowCollector.getRows();
            long currentTimeMillis = System.currentTimeMillis();
            for (Row row : rows) {
                int arity = row.getArity();
                Row row2 = new Row(arity + 2);
                row2.setField(0, new Timestamp(currentTimeMillis));
                row2.setField(1, Long.valueOf(rows.size()));
                for (int i = 0; i < arity; i++) {
                    if (i != 2 || row.getField(i) == null) {
                        row2.setField(2 + i, row.getField(i));
                    } else if (this.type.equals(LabelTypeEnum.StringTypeEnum.BIGINT) || this.type.equals(LabelTypeEnum.StringTypeEnum.LONG)) {
                        row2.setField(2 + i, Long.valueOf(Double.valueOf(row.getField(i).toString()).longValue()));
                    } else if (this.type.equals(LabelTypeEnum.StringTypeEnum.INT) || this.type.equals(LabelTypeEnum.StringTypeEnum.INTEGER)) {
                        row2.setField(2 + i, Integer.valueOf(Double.valueOf(row.getField(i).toString()).intValue()));
                    } else if (this.type.equals(LabelTypeEnum.StringTypeEnum.DOUBLE)) {
                        row2.setField(2 + i, Double.valueOf(row.getField(i).toString()));
                    } else if (this.type.equals(LabelTypeEnum.StringTypeEnum.FLOAT)) {
                        row2.setField(2 + i, Float.valueOf(Double.valueOf(row.getField(i).toString()).floatValue()));
                    } else {
                        row2.setField(2 + i, row.getField(i));
                    }
                }
                collector.collect(row2);
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Tuple2<Long, Object>) obj, (Collector<Row>) collector);
        }
    }

    public FtrlTrainStreamOp(BatchOperator<?> batchOperator) {
        this(batchOperator, new Params());
    }

    public FtrlTrainStreamOp(BatchOperator<?> batchOperator, Params params) {
        super(params);
        if (batchOperator == null) {
            throw new AkIllegalModelException("Online algo: initial model is null. Please set the initial model.");
        }
        this.dataBridge = DirectReader.collect(batchOperator);
        this.modelSchemeStr = TableUtil.schema2SchemaStr(batchOperator.getSchema());
    }

    /**
     * Builds the FTRL training topology on the single training-data input.
     *
     * <p>Training rows are rebalanced, grouped into local mini-batches and fed into a streaming iteration.
     * Inside the iteration every {@link ModelUpdater} emits per-batch gradients (version >= 0). These are
     * merged by a count window of {@code parallelism} messages, averaged per coefficient, and broadcast back
     * as feedback. The feedback also carries:
     * <ul>
     *   <li>optional rebase models read from a model-stream file;</li>
     *   <li>ticks every {@code timeInterval} seconds, which make each updater emit its current model
     *       (version -1). {@link WriteModel} serializes that model into model-stream rows, which form the
     *       output of this operator.</li>
     * </ul>
     *
     * <p>Side effect: sets the environment's checkpointing mode to EXACTLY_ONCE and forces checkpointing.
     */
    @Override
    public FtrlTrainStreamOp linkFrom(StreamOperator<?>... inputs) {
        StreamExecutionEnvironment env = MLEnvironmentFactory.get(getMLEnvironmentId()).getStreamExecutionEnvironment();
        int parallelism = env.getParallelism();
        env.getCheckpointConfig().setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);
        env.getCheckpointConfig().setForceCheckpointing(true);
        checkOpSize(1, inputs);
        StreamOperator<?> in = inputs[0];

        Params params = getParams();
        String vectorCol = getVectorCol();
        String[] featureCols = getFeatureCols();
        String[] colNames = in.getColNames();
        int vectorTrainIdx = vectorCol != null ? TableUtil.findColIndexWithAssertAndHint(colNames, vectorCol) : -1;
        int labelIdx = TableUtil.findColIndexWithAssertAndHint(colNames, getLabelCol());
        int[] featureIdx = null;
        if (vectorTrainIdx == -1) {
            featureIdx = new int[featureCols.length];
            for (int i = 0; i < featureCols.length; i++) {
                featureIdx[i] = TableUtil.findColIndexWithAssertAndHint(colNames, featureCols[i]);
            }
        }
        TypeInformation<?> labelType = in.getColTypes()[labelIdx];
        final int timeIntervalSeconds = getTimeInterval().intValue();

        // NOTE: transformations are created in the same order as before; auto-generated operator IDs (and
        // thus restoring state without explicit uids) may depend on that order.

        // Tick source: a single element whose flatMap never returns and emits (0, 0L) every interval.
        DataStream<Tuple2<Long, Object>> ticks = env
            .fromElements(Row.of(1))
            .flatMap(new FlatMapFunction<Row, Tuple2<Long, Object>>() {
                @Override
                public void flatMap(Row row, Collector<Tuple2<Long, Object>> collector) throws Exception {
                    // Long arithmetic: "seconds * 1000" in int overflows for intervals above ~24.8 days.
                    long sleepMillis = timeIntervalSeconds * 1000L;
                    for (long tick = 0; tick < Long.MAX_VALUE; tick++) {
                        Thread.sleep(sleepMillis);
                        collector.collect(Tuple2.of(0L, 0L));
                    }
                }
            });

        DataStream<Row[]> batches = in.getDataStream()
            .rebalance()
            .flatMap(new PrepareBatchSample(Math.max(1, getMiniBatchSize().intValue() / parallelism)));

        DataStream<Tuple2<Long, Object>> rebaseModels = null;
        if (ModelStreamUtils.useModelStreamFile(params)) {
            StreamOperator<?> modelSource = (StreamOperator<?>) new ModelStreamFileSourceStreamOp()
                .setFilePath(getModelStreamFilePath())
                .setScanInterval(getModelStreamScanInterval())
                .setStartTime(getModelStreamStartTime())
                .setSchemaStr(this.modelSchemeStr)
                .setMLEnvironmentId(in.getMLEnvironmentId());
            rebaseModels = modelSource.getDataStream().flatMap(new ParseRebaseModel(this.modelSchemeStr));
        }

        IterativeStream.ConnectedIterativeStreams<Row[], Tuple2<Long, Object>> iteration = batches
            .iterate(Long.MAX_VALUE)
            .withFeedbackType(TypeInformation.of(new TypeHint<Tuple2<Long, Object>>() {}));
        DataStream<Tuple2<Long, Object>> updaterOutput = iteration.flatMap(
            new ModelUpdater(this.dataBridge, params, featureIdx, vectorTrainIdx, labelIdx));

        // Gradients (version >= 0): merge `parallelism` messages, then average each coefficient's gradient.
        DataStream<Tuple2<Long, Object>> averagedGradients = updaterOutput
            .filter(new FilterFunction<Tuple2<Long, Object>>() {
                private static final long serialVersionUID = -5436758453355074895L;

                @Override
                public boolean filter(Tuple2<Long, Object> value) {
                    return value.f0 >= 0;
                }
            })
            .keyBy(0)
            .countWindowAll(parallelism)
            .reduce(new ReduceFunction<Tuple2<Long, Object>>() {
                @Override
                @SuppressWarnings("unchecked") // f1 of gradient messages is always Map<Integer, double[]>
                public Tuple2<Long, Object> reduce(Tuple2<Long, Object> acc, Tuple2<Long, Object> other) {
                    Map<Integer, double[]> target = (Map<Integer, double[]>) acc.f1;
                    Map<Integer, double[]> source = (Map<Integer, double[]>) other.f1;
                    for (Map.Entry<Integer, double[]> entry : source.entrySet()) {
                        double[] existing = target.get(entry.getKey());
                        if (existing == null) {
                            target.put(entry.getKey(), entry.getValue());
                        } else {
                            existing[0] += entry.getValue()[0];
                            existing[1] += entry.getValue()[1];
                        }
                    }
                    return acc;
                }
            })
            .map(new MapFunction<Tuple2<Long, Object>, Tuple2<Long, Object>>() {
                @Override
                @SuppressWarnings("unchecked") // f1 of gradient messages is always Map<Integer, double[]>
                public Tuple2<Long, Object> map(Tuple2<Long, Object> value) {
                    Map<Integer, double[]> summed = (Map<Integer, double[]>) value.f1;
                    int[] indices = new int[summed.size()];
                    double[] averages = new double[summed.size()];
                    int k = 0;
                    for (Map.Entry<Integer, double[]> entry : summed.entrySet()) {
                        indices[k] = entry.getKey();
                        averages[k] = entry.getValue()[0] / entry.getValue()[1];
                        k++;
                    }
                    return Tuple2.of(value.f0, new SparseVector(-1, indices, averages));
                }
            })
            .setParallelism(parallelism);

        DataStream<Tuple2<Long, Object>> feedback =
            (rebaseModels == null ? averagedGradients : averagedGradients.union(rebaseModels)).broadcast();

        // Models (version < 0) leave the iteration and are serialized as model-stream rows.
        DataStream<Row> modelRows = updaterOutput
            .filter(new FilterFunction<Tuple2<Long, Object>>() {
                private static final long serialVersionUID = 4204787383191799107L;

                @Override
                public boolean filter(Tuple2<Long, Object> value) {
                    return value.f0 < 0;
                }
            })
            .flatMap(new WriteModel(labelType, vectorCol, featureCols, getWithIntercept().booleanValue()));

        iteration.closeWith(feedback.union(ticks));

        TableSchema modelSchema = new LinearModelDataConverter(labelType).getModelSchema();
        String[] modelFieldNames = modelSchema.getFieldNames();
        TypeInformation<?>[] modelFieldTypes = modelSchema.getFieldTypes();
        String[] outNames = new String[modelFieldNames.length + 2];
        TypeInformation<?>[] outTypes = new TypeInformation<?>[modelFieldNames.length + 2];
        outNames[0] = ModelStreamUtils.MODEL_STREAM_TIMESTAMP_COLUMN_NAME;
        outNames[1] = ModelStreamUtils.MODEL_STREAM_COUNT_COLUMN_NAME;
        outTypes[0] = ModelStreamUtils.MODEL_STREAM_TIMESTAMP_COLUMN_TYPE;
        outTypes[1] = ModelStreamUtils.MODEL_STREAM_COUNT_COLUMN_TYPE;
        for (int i = 0; i < modelFieldNames.length; i++) {
            outTypes[i + 2] = modelFieldTypes[i];
            outNames[i + 2] = modelFieldNames[i];
        }
        setOutput(modelRows, outNames, outTypes);
        return this;
    }
}
