package com.alibaba.alink.common.comqueue;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.comqueue.BaseComQueue;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.BatchOperator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Builds a Flink bulk iteration from an ordered queue of {@link ComputeFunction} and
 * {@link CommunicateFunction} items.
 *
 * <p>Typical use:
 * <ul>
 *   <li>cache the inputs with {@link #initWithPartitionedData} or {@link #initWithBroadcastData};</li>
 *   <li>{@link #add} the per-superstep steps;</li>
 *   <li>optionally set a convergence criterion ({@link #setCompareCriterionOfNode0}) and an
 *       iteration limit ({@link #setMaxIter});</li>
 *   <li>register the result producer with {@link #closeWith};</li>
 *   <li>finally call {@link #exec()}.</li>
 * </ul>
 *
 * @param <Q> the concrete subclass, so the fluent setters return the subclass type
 */
public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {
    private static final long serialVersionUID = -4727279909132083484L;

    /** Convergence test evaluated on subtask 0 each superstep; {@code null} means "run maxIter steps". */
    private CompareCriterionFunction compareCriterion;
    /** Produces the output rows in the final superstep. */
    private CompleteResultFunction completeResult;
    /** Tail of the chain of data sets that populate the session cache; the loop start depends on it. */
    private transient DataSet<byte[]> cacheDataRel;
    private transient ExecutionEnvironment executionEnvironment;
    private final List<ComQueueItem> queue = new ArrayList<>();
    private final int sessionId = SessionSharedObjs.getNewSessionId();
    private int maxIter = Integer.MAX_VALUE;
    /** Names under which data sets were cached; handed to {@link DistributeData}. */
    private transient List<String> cacheDataObjNames = new ArrayList<>();

    /** Queue step that distributes the cached data sets to the current task in superstep 1 only. */
    public static class DistributeData extends ComputeFunction {
        private static final long serialVersionUID = -1105584217517972610L;
        private final List<String> cacheDataObjNames;
        private final int sessionId;

        DistributeData(List<String> cacheDataObjNames, int sessionId) {
            this.cacheDataObjNames = cacheDataObjNames;
            this.sessionId = sessionId;
        }

        @Override
        public void calc(ComContext context) {
            if (context.getStepNo() == 1) {
                SessionSharedObjs.distributeCachedData(cacheDataObjNames, sessionId, context.getTaskId());
            }
        }
    }

    /**
     * Collects a whole partition into memory and stores it in {@link SessionSharedObjs} under
     * {@code key} for the given session. Emits nothing.
     *
     * @param <T> element type of the cached data set
     */
    public static class PutCachedData<T> extends RichMapPartitionFunction<T, byte[]> {
        private static final long serialVersionUID = -6356063476350424243L;
        private final String key;
        private final int sessionId;

        PutCachedData(String key, int sessionId) {
            this.key = key;
            this.sessionId = sessionId;
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<byte[]> out) throws Exception {
            List<T> partition = new ArrayList<>();
            for (T value : values) {
                partition.add(value);
            }
            SessionSharedObjs.cachePartitionedData(key, sessionId, partition);
        }
    }

    /** Routes a record keyed by a subtask index to partition {@code key % numPartitions}. */
    private static final class TaskIdPartitioner implements Partitioner<Integer> {
        private static final long serialVersionUID = -8823979305373109564L;

        @Override
        public int partition(Integer key, int numPartitions) {
            return key % numPartitions;
        }
    }

    /** Self-type cast; safe as long as subclasses declare themselves as {@code X extends BaseComQueue<X>}. */
    @SuppressWarnings("unchecked")
    private Q thisAsQ() {
        return (Q) this;
    }

    /**
     * Appends a step to the iteration body.
     *
     * @param comQueueItem a {@link ComputeFunction} or {@link CommunicateFunction}
     * @return this queue
     */
    public Q add(ComQueueItem comQueueItem) {
        queue.add(comQueueItem);
        return thisAsQ();
    }

    int getMaxIter() {
        return maxIter;
    }

    List<ComQueueItem> getQueue() {
        return queue;
    }

    CompareCriterionFunction getCompareCriterion() {
        return compareCriterion;
    }

    /**
     * Sets the convergence test. It runs on subtask 0 only, and the loop stops once it returns true.
     *
     * @param compareCriterionFunction convergence test
     * @return this queue
     */
    public Q setCompareCriterionOfNode0(CompareCriterionFunction compareCriterionFunction) {
        this.compareCriterion = compareCriterionFunction;
        return thisAsQ();
    }

    /**
     * Sets the function that produces the output rows in the last superstep.
     *
     * @param completeResultFunction result producer; a {@code null} result list emits no rows
     * @return this queue
     */
    public Q closeWith(CompleteResultFunction completeResultFunction) {
        this.completeResult = completeResultFunction;
        return thisAsQ();
    }

    /**
     * Sets the maximum number of supersteps.
     *
     * @param maxIter upper bound on supersteps; defaults to {@link Integer#MAX_VALUE}
     * @return this queue
     */
    public Q setMaxIter(int maxIter) {
        this.maxIter = maxIter;
        return thisAsQ();
    }

    /**
     * Caches each partition of {@code dataSet} under {@code name} for this session.
     *
     * @param name     cache key
     * @param dataSet  data to cache, partition by partition
     * @param <T>      element type
     * @return this queue
     */
    public <T> Q initWithPartitionedData(String name, DataSet<T> dataSet) {
        createRelationshipAndCachedData(dataSet, name);
        return thisAsQ();
    }

    /**
     * Replicates {@code dataSet} to every parallel subtask, then caches it under {@code name}.
     *
     * @param name     cache key
     * @param dataSet  data to replicate and cache
     * @param <T>      element type
     * @return this queue
     */
    public <T> Q initWithBroadcastData(String name, DataSet<T> dataSet) {
        return initWithPartitionedData(name, broadcastDataSet(dataSet));
    }

    /**
     * Uses the execution environment registered under the given ML environment id.
     *
     * @param mlEnvId id passed to {@link MLEnvironmentFactory#get}
     * @return this queue
     */
    public Q initWithMLSessionId(Long mlEnvId) {
        this.executionEnvironment = MLEnvironmentFactory.get(mlEnvId).getExecutionEnvironment();
        return thisAsQ();
    }

    /**
     * Builds the iteration and returns the rows produced by {@link CompleteResultFunction}.
     *
     * <p>The result is emitted either in superstep {@code maxIter} or, when a criterion is set,
     * once that criterion reports convergence. The session's shared objects are cleared after the
     * loop's output has been produced.
     *
     * <p>Note: this mutates the queue (it prepends a {@link DistributeData} step and fuses adjacent
     * computations), so call it at most once per instance.
     *
     * @return the output rows
     * @throws AkUnsupportedOperationException if the queue holds an item that is neither a
     *                                         compute nor a communicate function
     */
    public DataSet<Row> exec() {
        queue.add(0, new DistributeData(cacheDataObjNames, sessionId));
        optimize();

        if (executionEnvironment == null) {
            executionEnvironment = cacheDataRel == null
                ? MLEnvironmentFactory.getDefault().getExecutionEnvironment()
                : BatchOperator.getExecutionEnvironmentFromDataSets(cacheDataRel);
        }

        IterativeDataSet<byte[]> loop = loopStartDataSet(executionEnvironment).iterate(maxIter);
        DataSet<byte[]> body = appendQueue(loop);
        DataSet<byte[]> loopEnd = compareCriterion == null
            ? closeLoopAtMaxIter(loop, body)
            : closeLoopOnCriterion(loop, body);
        return serializeModel(clearObjs(loopEnd));
    }

    /**
     * Chains every queued item onto the loop head.
     *
     * <p>Each computation takes its predecessor both as input and as a {@code "barrier"} broadcast
     * set.
     */
    private DataSet<byte[]> appendQueue(DataSet<byte[]> loopHead) {
        final int sid = sessionId;
        DataSet<byte[]> current = loopHead;
        for (ComQueueItem item : queue) {
            if (item instanceof CommunicateFunction) {
                current = ((CommunicateFunction) item).communicateWith(current, sid);
            } else if (item instanceof ComputeFunction) {
                final ComputeFunction computation = (ComputeFunction) item;
                String opName = item instanceof ChainedComputation
                    ? ((ChainedComputation) item).name()
                    : "computation@" + computation.getClass().getSimpleName();
                current = current
                    .mapPartition(new RichMapPartitionFunction<byte[], byte[]>() {
                        private static final long serialVersionUID = -6617692288474518056L;

                        @Override
                        public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
                            computation.calc(new ComContext(sid, getIterationRuntimeContext()));
                        }
                    })
                    .withBroadcastSet(current, "barrier")
                    .name(opName);
            } else {
                throw new AkUnsupportedOperationException("Unsupported op in iterative queue.");
            }
        }
        return current;
    }

    /** Closes the loop with no criterion: the result is emitted only in superstep {@code maxIter}. */
    private DataSet<byte[]> closeLoopAtMaxIter(IterativeDataSet<byte[]> loop, DataSet<byte[]> body) {
        final int sid = sessionId;
        final int lastStep = maxIter;
        final CompleteResultFunction result = completeResult;
        return loop.closeWith(
            body.mapPartition(new RichMapPartitionFunction<byte[], byte[]>() {
                    private static final long serialVersionUID = -5391702186659779224L;

                    @Override
                    public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
                        if (getIterationRuntimeContext().getSuperstepNumber() == lastStep) {
                            emitRows(result.calc(new ComContext(sid, getIterationRuntimeContext())), out);
                        }
                    }
                })
                .withBroadcastSet(body, "barrier")
                .name("genNewModel"));
    }

    /**
     * Closes the loop with a termination criterion.
     *
     * <p>Subtask 0 emits a single {@code false} while the criterion is not yet met. Flink ends a
     * bulk iteration when the termination-criterion data set is empty. The result step reads the
     * same data set as a broadcast variable, where an empty set means "converged".
     */
    private DataSet<byte[]> closeLoopOnCriterion(IterativeDataSet<byte[]> loop, DataSet<byte[]> body) {
        final int sid = sessionId;
        final int lastStep = maxIter;
        final CompareCriterionFunction criterion = compareCriterion;
        final CompleteResultFunction result = completeResult;

        DataSet<Boolean> notConverged = body
            .mapPartition(new RichMapPartitionFunction<byte[], Boolean>() {
                private static final long serialVersionUID = 6625968106516906392L;

                @Override
                public void mapPartition(Iterable<byte[]> values, Collector<Boolean> out) {
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0
                        && !criterion.calc(new ComContext(sid, getIterationRuntimeContext()))) {
                        out.collect(false);
                    }
                }
            })
            .withBroadcastSet(body, "barrier")
            .name("genCriterion");

        DataSet<byte[]> newModel = body
            .mapPartition(new RichMapPartitionFunction<byte[], byte[]>() {
                private static final long serialVersionUID = -2243669394358656436L;
                private boolean converged;

                @Override
                public void open(Configuration parameters) {
                    Boolean flag = getRuntimeContext().getBroadcastVariableWithInitializer(
                        "criterion",
                        new BroadcastVariableInitializer<Boolean, Boolean>() {
                            @Override
                            public Boolean initializeBroadcastVariable(Iterable<Boolean> data) {
                                Iterator<Boolean> it = data.iterator();
                                // No element means subtask 0 reported convergence.
                                return it.hasNext() ? it.next() : Boolean.TRUE;
                            }
                        });
                    converged = flag;
                }

                @Override
                public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
                    if (getIterationRuntimeContext().getSuperstepNumber() == lastStep || converged) {
                        emitRows(result.calc(new ComContext(sid, getIterationRuntimeContext())), out);
                    }
                }
            })
            .withBroadcastSet(body, "barrier")
            .withBroadcastSet(notConverged, "criterion")
            .name("genNewModel");

        return loop.closeWith(newModel, notConverged);
    }

    /** Serializes every row of {@code rows} into {@code out}; a {@code null} list emits nothing. */
    private static void emitRows(List<Row> rows, Collector<byte[]> out) {
        if (rows == null) {
            return;
        }
        for (Row row : rows) {
            out.collect(SerializationUtils.serialize(row));
        }
    }

    @Override
    public String toString() {
        HashMap<String, Object> fields = new HashMap<>();
        fields.put("queue", queue.stream()
            .map(item -> item.getClass().getSimpleName())
            .collect(Collectors.joining(",")));
        fields.put("sessionId", sessionId);
        fields.put("maxIter", maxIter);
        fields.put("compareCriterion", compareCriterion == null ? null : compareCriterion.getClass().getSimpleName());
        fields.put("completeResult", completeResult == null ? null : completeResult.getClass().getSimpleName());
        return JsonConverter.toJson(fields);
    }

    /** Deserializes the Java-serialized rows emitted by the loop back into {@link Row}s. */
    private static DataSet<Row> serializeModel(DataSet<byte[]> dataSet) {
        return dataSet
            .map(new MapFunction<byte[], Row>() {
                private static final long serialVersionUID = 7383520679708122544L;

                @Override
                public Row map(byte[] bytes) {
                    return (Row) SerializationUtils.deserialize(bytes);
                }
            })
            .name("serializeModel");
    }

    /** Sends every element to every parallel subtask (one copy per subtask index). */
    private static <T> DataSet<T> broadcastDataSet(DataSet<T> dataSet) {
        final TypeInformation<T> type = dataSet.getType();
        return expandDataSet2MaxParallelism(dataSet)
            .mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, T>>() {
                private static final long serialVersionUID = -4649163203694740662L;

                @Override
                public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, T>> out) {
                    int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
                    for (T value : values) {
                        for (int target = 0; target < numTasks; target++) {
                            out.collect(Tuple2.of(target, value));
                        }
                    }
                }
            })
            .returns(new TupleTypeInfo<Tuple2<Integer, T>>(Types.INT, type))
            .name("sharedDataBroadcast")
            .partitionCustom(new TaskIdPartitioner(), 0)
            .name("sharedDataPartition")
            .mapPartition(new RichMapPartitionFunction<Tuple2<Integer, T>, T>() {
                private static final long serialVersionUID = -4348660942756396179L;

                @Override
                public void mapPartition(Iterable<Tuple2<Integer, T>> values, Collector<T> out) {
                    for (Tuple2<Integer, T> tagged : values) {
                        out.collect(tagged.f1);
                    }
                }
            })
            .returns(type)
            .name("sharedDataFly");
    }

    /**
     * Tags each element with its producing subtask index and custom-partitions on that index.
     * The data therefore lands in the downstream operator, whose parallelism is the default
     * (presumably the maximum — confirm with the job configuration).
     */
    private static <T> DataSet<T> expandDataSet2MaxParallelism(DataSet<T> dataSet) {
        final TypeInformation<T> type = dataSet.getType();
        return dataSet
            .map(new RichMapFunction<T, Tuple2<Integer, T>>() {
                private static final long serialVersionUID = -3563752831074380126L;

                @Override
                public Tuple2<Integer, T> map(T value) {
                    return Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), value);
                }
            })
            .returns(new TupleTypeInfo<Tuple2<Integer, T>>(Types.INT, type))
            .name("appendTaskId2Data")
            .partitionCustom(new TaskIdPartitioner(), 0)
            .name("partitionData2Task")
            .map(new MapFunction<Tuple2<Integer, T>, T>() {
                private static final long serialVersionUID = 7735842319931188500L;

                @Override
                public T map(Tuple2<Integer, T> tagged) {
                    return tagged.f1;
                }
            })
            .returns(type)
            .name("projectData2Raw");
    }

    /**
     * Creates the (empty) initial data set of the loop.
     *
     * <p>When data has been cached, the cache chain is attached as a {@code "rel"} broadcast set, so
     * the loop starts only after caching has run.
     */
    private DataSet<byte[]> loopStartDataSet(ExecutionEnvironment env) {
        MapPartitionOperator<Integer, byte[]> start = env
            .fromElements(1)
            .rebalance()
            .mapPartition(new MapPartitionFunction<Integer, byte[]>() {
                private static final long serialVersionUID = 1605194585509760448L;

                @Override
                public void mapPartition(Iterable<Integer> values, Collector<byte[]> out) {
                    // Intentionally emits nothing: the loop's payload lives in SessionSharedObjs.
                }
            })
            .name("iterInitialize");
        if (cacheDataRel != null) {
            start = start.withBroadcastSet(cacheDataRel, "rel");
        }
        return start;
    }

    /**
     * Returns {@code dataSet} unchanged, but gated on a step that calls
     * {@link SessionSharedObjs#clear(int)} for this session.
     *
     * <p>That clear step itself waits (via a {@code "barrier"} broadcast) on a no-op pass over
     * {@code dataSet}.
     */
    private DataSet<byte[]> clearObjs(DataSet<byte[]> dataSet) {
        final int sid = sessionId;

        DataSet<byte[]> upstreamDone = dataSet.mapPartition(new MapPartitionFunction<byte[], byte[]>() {
            private static final long serialVersionUID = 570124206050744389L;

            @Override
            public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
                // No output: used only as a broadcast barrier.
            }
        });

        DataSet<byte[]> cleared = expandDataSet2MaxParallelism(
                BatchOperator.getExecutionEnvironmentFromDataSets(dataSet).fromElements(0))
            .mapPartition(new RichMapPartitionFunction<Integer, byte[]>() {
                private static final long serialVersionUID = -7819774126101954367L;

                @Override
                public void mapPartition(Iterable<Integer> values, Collector<byte[]> out) {
                    SessionSharedObjs.clear(sid);
                }
            })
            .withBroadcastSet(upstreamDone, "barrier");

        return dataSet
            .map(new MapFunction<byte[], byte[]>() {
                private static final long serialVersionUID = 6060666608672449498L;

                @Override
                public byte[] map(byte[] bytes) {
                    return bytes;
                }
            })
            .withBroadcastSet(cleared, "barrier")
            .name("clearReturn");
    }

    /**
     * Extends the cache chain with a step that stores {@code dataSet} under {@code name}.
     *
     * <p>The first call anchors the chain on a {@link #clearObjs} step.
     */
    private <T> void createRelationshipAndCachedData(DataSet<T> dataSet, String name) {
        if (cacheDataRel == null) {
            cacheDataRel = clearObjs(dataSet.mapPartition(new MapPartitionFunction<T, byte[]>() {
                private static final long serialVersionUID = 5119252579498807853L;

                @Override
                public void mapPartition(Iterable<T> values, Collector<byte[]> out) {
                    // No output: only establishes the dependency.
                }
            }));
        }
        cacheDataRel = dataSet
            .mapPartition(new PutCachedData<T>(name, sessionId))
            .withBroadcastSet(cacheDataRel, "rel")
            .name("cachedDataRel@" + name);
        cacheDataObjNames.add(name);
    }

    /**
     * Fuses each run of adjacent {@link ComputeFunction}s in the queue into a single
     * {@link ChainedComputation}, compacting the queue in place. Communication steps are left as-is.
     */
    private void optimize() {
        if (queue.isEmpty()) {
            return;
        }
        int tail = 0;
        for (int next = 1; next < queue.size(); next++) {
            ComQueueItem last = queue.get(tail);
            ComQueueItem item = queue.get(next);
            if (last instanceof ComputeFunction && item instanceof ComputeFunction) {
                ChainedComputation chain = last instanceof ChainedComputation
                    ? (ChainedComputation) last
                    : new ChainedComputation().add((ComputeFunction) last);
                queue.set(tail, chain.add((ComputeFunction) item));
            } else {
                queue.set(++tail, item);
            }
        }
        queue.subList(tail + 1, queue.size()).clear();
    }
}
