package com.alibaba.alink.operator.batch.timeseries;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamCond;
import com.alibaba.alink.common.annotation.ParamMutexRule;
import com.alibaba.alink.common.annotation.ParamMutexRules;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.linalg.tensor.FloatTensor;
import com.alibaba.alink.common.linalg.tensor.Shape;
import com.alibaba.alink.common.linalg.tensor.Tensor;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.nlp.WordCountUtil;
import com.alibaba.alink.operator.common.timeseries.DeepARFeaturesGenerator;
import com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter;
import com.alibaba.alink.operator.common.timeseries.TimestampUtil;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.params.timeseries.DeepARPreProcessParams;
import com.alibaba.alink.params.timeseries.DeepARTrainParams;
import com.alibaba.alink.params.timeseries.HasTimeFrequency;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "timeCol", allowedTypeCollections = {TypeCollections.TIMESTAMP_TYPES}), @ParamSelectColumnSpec(name = "selectedCol"), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})})
@NameCn("DeepAR训练")
@ParamMutexRules({@ParamMutexRule(name = "vectorCol", type = ParamMutexRule.ActionType.DISABLE, cond = @ParamCond(name = "selectedCol", type = ParamCond.CondType.WHEN_NOT_NULL)), @ParamMutexRule(name = "selectedCol", type = ParamMutexRule.ActionType.DISABLE, cond = @ParamCond(name = "vectorCol", type = ParamCond.CondType.WHEN_NOT_NULL))})
@NameEn("Deep AR Training")
/* loaded from: input_file:com/alibaba/alink/operator/batch/timeseries/DeepARTrainBatchOp.class */
public class DeepARTrainBatchOp extends BatchOperator<DeepARTrainBatchOp> implements DeepARTrainParams<DeepARTrainBatchOp> {

    /* JADX INFO: Access modifiers changed from: private */
    @Internal
    /* loaded from: input_file:com/alibaba/alink/operator/batch/timeseries/DeepARTrainBatchOp$DeepARPreProcessBatchOp.class */
    public static class DeepARPreProcessBatchOp extends BatchOperator<DeepARPreProcessBatchOp> implements DeepARPreProcessParams<DeepARPreProcessBatchOp> {
        public DeepARPreProcessBatchOp() {
            this(new Params());
        }

        public DeepARPreProcessBatchOp(Params params) {
            super(params);
        }

        /* JADX WARN: Can't rename method to resolve collision */
        @Override // com.alibaba.alink.operator.batch.BatchOperator
        public DeepARPreProcessBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
            checkMinOpSize(1, batchOperatorArr);
            BatchOperator<?> batchOperator = batchOperatorArr[0];
            String vectorCol = getParams().contains(VECTOR_COL) ? getVectorCol() : getSelectedCol();
            AkPreconditions.checkNotNull(vectorCol);
            String timeCol = getTimeCol();
            BatchOperator<?> select = Preprocessing.select(batchOperator, timeCol, vectorCol);
            final int findColIndexWithAssertAndHint = TableUtil.findColIndexWithAssertAndHint(select.getColNames(), vectorCol);
            final int findColIndexWithAssertAndHint2 = TableUtil.findColIndexWithAssertAndHint(select.getColNames(), timeCol);
            MapOperator map = batchOperatorArr.length > 1 ? batchOperatorArr[1].getDataSet().map(new MapFunction<Row, HasTimeFrequency.TimeFrequency>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.DeepARPreProcessBatchOp.1
                public HasTimeFrequency.TimeFrequency map(Row row) throws Exception {
                    return (HasTimeFrequency.TimeFrequency) Params.fromJson(String.valueOf(row.getField(0))).get(HasTimeFrequency.TIME_FREQUENCY);
                }
            }) : select.getDataSet().mapPartition(new MapPartitionFunction<Row, Tuple2<Timestamp, Timestamp>>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.DeepARPreProcessBatchOp.4
                public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Timestamp, Timestamp>> collector) {
                    Timestamp timestamp = null;
                    Timestamp timestamp2 = null;
                    Iterator<Row> it = iterable.iterator();
                    while (it.hasNext()) {
                        Timestamp timestamp3 = (Timestamp) it.next().getField(findColIndexWithAssertAndHint2);
                        if (timestamp3 != null) {
                            if (timestamp == null) {
                                timestamp = timestamp3;
                                timestamp2 = timestamp3;
                            } else {
                                timestamp = timestamp.compareTo(timestamp3) < 0 ? timestamp : timestamp3;
                                timestamp2 = timestamp2.compareTo(timestamp3) > 0 ? timestamp2 : timestamp3;
                            }
                        }
                    }
                    if (timestamp != null) {
                        collector.collect(Tuple2.of(timestamp, timestamp2));
                    }
                }
            }).reduce(new ReduceFunction<Tuple2<Timestamp, Timestamp>>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.DeepARPreProcessBatchOp.3
                public Tuple2<Timestamp, Timestamp> reduce(Tuple2<Timestamp, Timestamp> tuple2, Tuple2<Timestamp, Timestamp> tuple22) {
                    return Tuple2.of(((Timestamp) tuple2.f0).compareTo((Timestamp) tuple22.f0) < 0 ? (Timestamp) tuple2.f0 : (Timestamp) tuple22.f0, ((Timestamp) tuple2.f1).compareTo((Timestamp) tuple22.f1) > 0 ? (Timestamp) tuple2.f1 : (Timestamp) tuple22.f1);
                }
            }).reduceGroup(new RichGroupReduceFunction<Tuple2<Timestamp, Timestamp>, HasTimeFrequency.TimeFrequency>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.DeepARPreProcessBatchOp.2
                private transient long cnt;

                public void open(Configuration configuration) throws Exception {
                    this.cnt = ((Long) getRuntimeContext().getBroadcastVariableWithInitializer(WordCountUtil.COUNT_COL_NAME, new BroadcastVariableInitializer<Tuple1<Long>, Long>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.DeepARPreProcessBatchOp.2.1
                        public Long initializeBroadcastVariable(Iterable<Tuple1<Long>> iterable) {
                            return (Long) iterable.iterator().next().f0;
                        }

                        /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                        public /* bridge */ /* synthetic */ Object m299initializeBroadcastVariable(Iterable iterable) {
                            return initializeBroadcastVariable((Iterable<Tuple1<Long>>) iterable);
                        }
                    })).longValue();
                }

                public void reduce(Iterable<Tuple2<Timestamp, Timestamp>> iterable, Collector<HasTimeFrequency.TimeFrequency> collector) {
                    if (this.cnt == 0) {
                        collector.collect(HasTimeFrequency.TimeFrequency.MONTHLY);
                        return;
                    }
                    Iterator<Tuple2<Timestamp, Timestamp>> it = iterable.iterator();
                    if (!it.hasNext()) {
                        collector.collect(HasTimeFrequency.TimeFrequency.MONTHLY);
                    } else {
                        Tuple2<Timestamp, Timestamp> next = it.next();
                        collector.collect(DeepARFeaturesGenerator.generateFrequency((Timestamp) next.f0, (Timestamp) next.f1, this.cnt));
                    }
                }
            }).withBroadcastSet(DataSetUtils.countElementsPerPartition(select.getDataSet()).project(new int[]{1}).sum(0), WordCountUtil.COUNT_COL_NAME);
            Tuple2<DataSet<Tuple2<Integer, Row>>, DataSet<Tuple2<Integer, Long>>> pSort = SortUtils.pSort(select.getDataSet(), findColIndexWithAssertAndHint2);
            String[] outputCols = getOutputCols();
            AkPreconditions.checkState(outputCols != null && (outputCols.length == 2 || outputCols.length == 3));
            final boolean z = outputCols.length == 3;
            TypeInformation<?>[] typeInformationArr = z ? new TypeInformation[]{AlinkTypes.FLOAT_TENSOR, AlinkTypes.FLOAT_TENSOR, AlinkTypes.FLOAT_TENSOR} : new TypeInformation[]{AlinkTypes.FLOAT_TENSOR, AlinkTypes.FLOAT_TENSOR};
            final int intValue = getWindow().intValue();
            final int intValue2 = getStride().intValue();
            setOutput(((DataSet) pSort.f0).partitionByHash(new int[]{0}).mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Row>, Row>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.DeepARPreProcessBatchOp.5
                private final TimestampUtil.TimestampToCalendar calendar = new TimestampUtil.TimestampToCalendar();
                private transient HasTimeFrequency.TimeFrequency frequency;

                public void open(Configuration configuration) throws Exception {
                    this.frequency = (HasTimeFrequency.TimeFrequency) getRuntimeContext().getBroadcastVariableWithInitializer("frequency", new BroadcastVariableInitializer<HasTimeFrequency.TimeFrequency, HasTimeFrequency.TimeFrequency>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.DeepARPreProcessBatchOp.5.1
                        public HasTimeFrequency.TimeFrequency initializeBroadcastVariable(Iterable<HasTimeFrequency.TimeFrequency> iterable) {
                            return iterable.iterator().next();
                        }

                        /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                        public /* bridge */ /* synthetic */ Object m300initializeBroadcastVariable(Iterable iterable) {
                            return initializeBroadcastVariable((Iterable<HasTimeFrequency.TimeFrequency>) iterable);
                        }
                    });
                }

                public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Row> collector) {
                    long j = 0;
                    ArrayList arrayList = new ArrayList();
                    for (Tuple2<Integer, Row> tuple2 : iterable) {
                        DenseVector denseVector = VectorUtil.getDenseVector(((Row) tuple2.f1).getField(findColIndexWithAssertAndHint));
                        if (denseVector == null) {
                            arrayList.add(Tuple3.of(tuple2.f0, (Object) null, (Object) null));
                        } else {
                            j = denseVector.size();
                            arrayList.add(Tuple3.of(tuple2.f0, FloatTensor.of(TensorUtil.getTensor(denseVector)).reshape2(new Shape(j, 1)), DeepARFeaturesGenerator.generateFromFrequency(this.calendar, this.frequency, (Timestamp) ((Row) tuple2.f1).getField(findColIndexWithAssertAndHint2))));
                        }
                    }
                    arrayList.sort(Comparator.comparing(tuple3 -> {
                        return (Integer) tuple3.f0;
                    }));
                    int size = arrayList.size();
                    FloatTensor[] floatTensorArr = new FloatTensor[intValue];
                    FloatTensor floatTensor = new FloatTensor(new float[intValue]);
                    int i = intValue - intValue2;
                    int i2 = size - intValue;
                    for (int i3 = 0; i3 < j; i3++) {
                        int i4 = 0;
                        while (true) {
                            int i5 = i4;
                            if (i5 < i2) {
                                int i6 = i5 + intValue;
                                FloatTensor floatTensor2 = new FloatTensor(new float[]{0.0f, 0.0f});
                                float f = 0.0f;
                                int i7 = i5;
                                int i8 = 0;
                                while (i7 < i6) {
                                    FloatTensor floatTensor3 = new FloatTensor(new Shape(1));
                                    if (i7 != i5) {
                                        Tuple3 tuple32 = (Tuple3) arrayList.get(i7 - 1);
                                        if (i8 <= i) {
                                            float f2 = ((FloatTensor) tuple32.f1).getFloat(i3, 0);
                                            if (f2 != 0.0f) {
                                                f += 1.0f;
                                            }
                                            floatTensor2.setFloat(floatTensor2.getFloat(0) + f2, 0);
                                        }
                                        floatTensor3.setFloat(((FloatTensor) tuple32.f1).getFloat(i3, 0), 0);
                                    } else {
                                        floatTensor3.setFloat(0.0f, 0);
                                    }
                                    floatTensorArr[i8] = (FloatTensor) Tensor.cat(new FloatTensor[]{floatTensor3, (FloatTensor) ((Tuple3) arrayList.get(i7)).f2}, -1, null);
                                    if (z) {
                                        floatTensor.setFloat(((FloatTensor) ((Tuple3) arrayList.get(i7)).f1).getFloat(i3, 0), i8);
                                    }
                                    i7++;
                                    i8++;
                                }
                                if (f == 0.0f) {
                                    floatTensor2.setFloat(0.0f, 0);
                                } else {
                                    floatTensor2.setFloat((floatTensor2.getFloat(0) / f) + 1.0f, new long[0]);
                                    for (int i9 = 0; i9 < intValue; i9++) {
                                        floatTensorArr[i9].setFloat(floatTensorArr[i9].getFloat(0) / floatTensor2.getFloat(0), 0);
                                        if (z) {
                                            floatTensor.setFloat(floatTensor.getFloat(i9) / floatTensor2.getFloat(0), i9);
                                        }
                                    }
                                }
                                if (z) {
                                    collector.collect(Row.of(new Object[]{Tensor.stack(floatTensorArr, 0, null), floatTensor2, floatTensor}));
                                } else {
                                    collector.collect(Row.of(new Object[]{Tensor.stack(floatTensorArr, 0, null), floatTensor2}));
                                }
                                i4 = i5 + intValue2;
                            }
                        }
                    }
                }
            }).withBroadcastSet(map, "frequency"), outputCols, typeInformationArr);
            setSideOutputTables(new Table[]{DataSetConversionUtil.toTable(getMLEnvironmentId(), (DataSet<Row>) map.map(new MapFunction<HasTimeFrequency.TimeFrequency, Row>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.DeepARPreProcessBatchOp.6
                public Row map(HasTimeFrequency.TimeFrequency timeFrequency) {
                    return Row.of(new Object[]{timeFrequency});
                }
            }), new String[]{"frequency"}, (TypeInformation<?>[]) new TypeInformation[]{TypeInformation.of(HasTimeFrequency.TimeFrequency.class)})});
            return this;
        }

        @Override // com.alibaba.alink.operator.batch.BatchOperator
        public /* bridge */ /* synthetic */ DeepARPreProcessBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
            return linkFrom((BatchOperator<?>[]) batchOperatorArr);
        }
    }

    public DeepARTrainBatchOp() {
        this(new Params());
    }

    public DeepARTrainBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public DeepARTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        DeepARPreProcessBatchOp linkFrom = ((DeepARPreProcessBatchOp) new DeepARPreProcessBatchOp(getParams().m1495clone()).setOutputCols("tensor", "v", "y").setMLEnvironmentId(getMLEnvironmentId())).linkFrom(checkAndGetFirst(batchOperatorArr));
        HashMap hashMap = new HashMap();
        hashMap.put("window", getWindow());
        hashMap.put("stride", getStride());
        HashMap hashMap2 = new HashMap();
        hashMap2.put("tensorCol", "tensor");
        hashMap2.put("labelCol", "y");
        hashMap2.put("batch_size", String.valueOf(getBatchSize()));
        hashMap2.put("num_epochs", String.valueOf(getNumEpochs()));
        hashMap2.put("model_config", JsonConverter.toJson(hashMap));
        TFTableModelTrainBatchOp linkFrom2 = ((TFTableModelTrainBatchOp) new TFTableModelTrainBatchOp(getParams().m1495clone()).setSelectedCols("tensor", "y").setUserFiles(new String[]{"res:///tf_algos/deepar_entry.py"}).setMainScriptFile("res:///tf_algos/deepar_entry.py").setUserParams(JsonConverter.toJson(hashMap2)).setMLEnvironmentId(getMLEnvironmentId())).linkFrom(linkFrom);
        final Params params = getParams();
        setOutput((DataSet<Row>) linkFrom2.getDataSet().reduceGroup(new RichGroupReduceFunction<Row, Row>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.1
            private transient HasTimeFrequency.TimeFrequency frequency;

            public void open(Configuration configuration) throws Exception {
                this.frequency = (HasTimeFrequency.TimeFrequency) getRuntimeContext().getBroadcastVariableWithInitializer("frequency", new BroadcastVariableInitializer<HasTimeFrequency.TimeFrequency, HasTimeFrequency.TimeFrequency>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.1.1
                    public HasTimeFrequency.TimeFrequency initializeBroadcastVariable(Iterable<HasTimeFrequency.TimeFrequency> iterable) {
                        return iterable.iterator().next();
                    }

                    /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                    public /* bridge */ /* synthetic */ Object m298initializeBroadcastVariable(Iterable iterable) {
                        return initializeBroadcastVariable((Iterable<HasTimeFrequency.TimeFrequency>) iterable);
                    }
                });
            }

            public void reduce(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
                ArrayList arrayList = new ArrayList();
                Iterator<Row> it = iterable.iterator();
                while (it.hasNext()) {
                    arrayList.add(it.next());
                }
                new DeepARModelDataConverter().save(new DeepARModelDataConverter.DeepARModelData(params.m1495clone().set((ParamInfo<ParamInfo<HasTimeFrequency.TimeFrequency>>) HasTimeFrequency.TIME_FREQUENCY, (ParamInfo<HasTimeFrequency.TimeFrequency>) this.frequency), arrayList), collector);
            }
        }).withBroadcastSet(linkFrom.getSideOutput(0).getDataSet().map(new MapFunction<Row, HasTimeFrequency.TimeFrequency>() { // from class: com.alibaba.alink.operator.batch.timeseries.DeepARTrainBatchOp.2
            public HasTimeFrequency.TimeFrequency map(Row row) throws Exception {
                return (HasTimeFrequency.TimeFrequency) row.getField(0);
            }
        }), "frequency"), new DeepARModelDataConverter().getModelSchema());
        return this;
    }

    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ DeepARTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
