package com.alibaba.alink.operator.batch.timeseries;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamCond;
import com.alibaba.alink.common.annotation.ParamMutexRule;
import com.alibaba.alink.common.annotation.ParamMutexRules;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.tensor.FloatTensor;
import com.alibaba.alink.common.linalg.tensor.Tensor;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.params.timeseries.LSTNetPreProcessParams;
import com.alibaba.alink.params.timeseries.LSTNetTrainParams;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Trains an LSTNet model on a time series.
 *
 * <p>The input is first converted into supervised samples by {@link LSTNetPreProcessBatchOp}.
 * Each sample stacks {@code window} consecutive observations into column {@code "tensor"}. Its label
 * column {@code "y"} holds the observation {@code horizon} steps after the last one in the window.
 * The samples are then passed to {@link TFTableModelTrainBatchOp}. That operator runs the bundled
 * {@code lstnet_entry.py} script with this operator's batch size, number of epochs and
 * window/horizon settings.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "timeCol", allowedTypeCollections = {TypeCollections.TIMESTAMP_TYPES}), @ParamSelectColumnSpec(name = "selectedCol"), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})})
@NameCn("LSTNet训练")
@ParamMutexRules({@ParamMutexRule(name = "vectorCol", type = ParamMutexRule.ActionType.DISABLE, cond = @ParamCond(name = "selectedCol", type = ParamCond.CondType.WHEN_NOT_NULL)), @ParamMutexRule(name = "selectedCol", type = ParamMutexRule.ActionType.DISABLE, cond = @ParamCond(name = "vectorCol", type = ParamCond.CondType.WHEN_NOT_NULL))})
@NameEn("LSTNet Training")
public class LSTNetTrainBatchOp extends BatchOperator<LSTNetTrainBatchOp> implements LSTNetTrainParams<LSTNetTrainBatchOp> {

    /** Name of the stacked-window tensor column produced by the preprocessing step. */
    private static final String PREPROCESS_TENSOR_COL = "tensor";

    /** Name of the label column produced by the preprocessing step. */
    private static final String PREPROCESS_LABEL_COL = "y";

    /** TensorFlow entry script; shipped as a user file and also used as the main script. */
    private static final String LSTNET_ENTRY_SCRIPT = "res:///tf_algos/lstnet_entry.py";

    /**
     * Internal preprocessing step that turns a time series into sliding-window samples.
     *
     * <p>The value column is {@code vectorCol}, or {@code selectedCol} when {@code vectorCol} is not set.
     * Rows are passed through {@link SortUtils#pSort} on {@code timeCol}. They are then
     * hash-partitioned by the partition id that pSort assigns. Within each hash partition they are
     * ordered by (partition id, time).
     *
     * <p>One output row is emitted for each position {@code t >= window + horizon - 1} of that ordered
     * list:
     * <ul>
     *   <li>Its first column is {@code Tensor.stack} (dimension 0) of the {@code window} vectors
     *       ending at position {@code t - horizon}.</li>
     *   <li>When two output columns are configured, its second column is the vector at position
     *       {@code t}.</li>
     * </ul>
     *
     * <p>NOTE(review): samples are built independently per hash partition, so a window never spans
     * rows that landed in different partitions. Confirm this is acceptable for parallelism &gt; 1.
     */
    @Internal
    public static class LSTNetPreProcessBatchOp extends BatchOperator<LSTNetPreProcessBatchOp> implements LSTNetPreProcessParams<LSTNetPreProcessBatchOp> {

        public LSTNetPreProcessBatchOp() {
            this(new Params());
        }

        public LSTNetPreProcessBatchOp(Params params) {
            super(params);
        }

        /**
         * Builds the sliding-window sample table from the input time series.
         *
         * @param inputs the input operator (taken via {@code checkAndGetFirst})
         * @return this operator, with its output set to one or two {@code FLOAT_TENSOR} columns
         *     named by {@code outputCols}
         */
        @Override
        public LSTNetPreProcessBatchOp linkFrom(BatchOperator<?>... inputs) {
            BatchOperator<?> in = checkAndGetFirst(inputs);

            String vectorCol = getParams().contains(VECTOR_COL) ? getVectorCol() : getSelectedCol();
            AkPreconditions.checkNotNull(vectorCol);
            String timeCol = getTimeCol();

            // One output column: features only. Two output columns: features and label.
            String[] outputCols = getOutputCols();
            AkPreconditions.checkState(outputCols != null && (outputCols.length == 1 || outputCols.length == 2));
            boolean withLabel = outputCols.length == 2;

            BatchOperator<?> selected = Preprocessing.select(in, timeCol, vectorCol);
            int vectorColIndex = TableUtil.findColIndexWithAssertAndHint(selected.getColNames(), vectorCol);
            int timeColIndex = TableUtil.findColIndexWithAssertAndHint(selected.getColNames(), timeCol);

            Tuple2<DataSet<Tuple2<Integer, Row>>, DataSet<Tuple2<Integer, Long>>> sorted =
                SortUtils.pSort(selected.getDataSet(), timeColIndex);

            DataSet<Row> samples = sorted.f0
                .partitionByHash(0)
                .mapPartition(new SlidingWindowSampler(
                    vectorColIndex, timeColIndex, getWindow(), getHorizon(), withLabel));

            TypeInformation<?>[] outputTypes = withLabel
                ? new TypeInformation<?>[] {AlinkTypes.FLOAT_TENSOR, AlinkTypes.FLOAT_TENSOR}
                : new TypeInformation<?>[] {AlinkTypes.FLOAT_TENSOR};
            setOutput(samples, outputCols, outputTypes);
            return this;
        }

        /**
         * Converts the rows of one hash partition into sliding-window samples.
         *
         * <p>This is a static class, so Flink serializes only these index and size fields and not
         * the enclosing operator.
         */
        private static final class SlidingWindowSampler implements MapPartitionFunction<Tuple2<Integer, Row>, Row> {

            private static final long serialVersionUID = 1L;

            private final int vectorColIndex;
            private final int timeColIndex;
            private final int window;
            private final int horizon;
            private final boolean withLabel;

            SlidingWindowSampler(int vectorColIndex, int timeColIndex, int window, int horizon, boolean withLabel) {
                this.vectorColIndex = vectorColIndex;
                this.timeColIndex = timeColIndex;
                this.window = window;
                this.horizon = horizon;
                this.withLabel = withLabel;
            }

            @Override
            public void mapPartition(Iterable<Tuple2<Integer, Row>> values, Collector<Row> out) {
                // Copy the partition id, time value and tensor into a fresh holder per record.
                List<Observation> observations = new ArrayList<>();
                for (Tuple2<Integer, Row> value : values) {
                    Row row = value.f1;
                    observations.add(new Observation(
                        value.f0,
                        row.getField(timeColIndex),
                        FloatTensor.of(TensorUtil.getTensor(row.getField(vectorColIndex)))));
                }

                // Order by pSort partition id, then explicitly by time. Stable, so rows that are
                // already time-ordered keep their relative order.
                observations.sort(Comparator.comparingInt((Observation o) -> o.partitionId)
                    .thenComparing((a, b) -> compareNullsFirst(a.time, b.time)));

                int count = observations.size();
                // Smallest target index whose window starts at index 0.
                for (int target = window + horizon - 1; target < count; target++) {
                    int windowEnd = target - horizon + 1; // exclusive
                    int windowStart = windowEnd - window;
                    FloatTensor[] history = new FloatTensor[window];
                    for (int k = 0; k < window; k++) {
                        history[k] = observations.get(windowStart + k).tensor;
                    }
                    Object features = Tensor.stack(history, 0, null);
                    if (withLabel) {
                        out.collect(Row.of(features, observations.get(target).tensor));
                    } else {
                        out.collect(Row.of(features));
                    }
                }
            }

            /** Compares two time values, ordering {@code null} before any non-null value. */
            @SuppressWarnings({"unchecked", "rawtypes"})
            private static int compareNullsFirst(Object left, Object right) {
                if (left == right) {
                    return 0;
                }
                if (left == null) {
                    return -1;
                }
                if (right == null) {
                    return 1;
                }
                // Assumes the time column holds mutually Comparable values (e.g. timestamps).
                return ((Comparable) left).compareTo(right);
            }
        }

        /** One input record reduced to what sampling needs. Used only locally, never serialized. */
        private static final class Observation {
            final int partitionId;
            final Object time;
            final FloatTensor tensor;

            Observation(int partitionId, Object time, FloatTensor tensor) {
                this.partitionId = partitionId;
                this.time = time;
                this.tensor = tensor;
            }
        }
    }

    public LSTNetTrainBatchOp() {
        this(new Params());
    }

    public LSTNetTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Preprocesses the input into (tensor, label) samples and trains the LSTNet TensorFlow model.
     *
     * @param inputs the input operator (taken via {@code checkAndGetFirst})
     * @return this operator, whose output table is the model table produced by
     *     {@link TFTableModelTrainBatchOp}
     */
    @Override
    public LSTNetTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);

        LSTNetPreProcessBatchOp preprocess = new LSTNetPreProcessBatchOp(getParams().clone());
        preprocess.setOutputCols(PREPROCESS_TENSOR_COL, PREPROCESS_LABEL_COL);
        preprocess.setMLEnvironmentId(getMLEnvironmentId());
        preprocess.linkFrom(in);

        // Model hyper-parameters serialized as JSON for the python entry script.
        Map<String, Object> modelConfig = new HashMap<>();
        modelConfig.put("window", getWindow());
        modelConfig.put("horizon", getHorizon());

        Map<String, String> userParams = new HashMap<>();
        userParams.put("tensorCol", PREPROCESS_TENSOR_COL);
        userParams.put("labelCol", PREPROCESS_LABEL_COL);
        userParams.put("batch_size", String.valueOf(getBatchSize()));
        userParams.put("num_epochs", String.valueOf(getNumEpochs()));
        userParams.put("model_config", JsonConverter.toJson(modelConfig));

        TFTableModelTrainBatchOp trainer = new TFTableModelTrainBatchOp(getParams().clone());
        trainer.setSelectedCols(PREPROCESS_TENSOR_COL, PREPROCESS_LABEL_COL);
        trainer.setUserFiles(new String[] {LSTNET_ENTRY_SCRIPT});
        trainer.setMainScriptFile(LSTNET_ENTRY_SCRIPT);
        trainer.setUserParams(JsonConverter.toJson(userParams));
        // Set the environment before linking, the same way as for the preprocessing step.
        trainer.setMLEnvironmentId(getMLEnvironmentId());
        trainer.linkFrom(preprocess);

        setOutputTable(trainer.getOutputTable());
        return this;
    }
}
