package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CommunicateFunction;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/Bcast.class */
/**
 * Communication step that copies a named buffer from one "root" task to every other task.
 *
 * <p>On the root task the object stored under the buffer name is read as an {@code Iterable}
 * and each of its elements is emitted once per task index. The records are routed with a
 * custom partitioner keyed on that index, and every non-root task gathers what it receives
 * into a list that it stores under the same buffer name.
 *
 * @param <DT> element type of the buffer being distributed
 */
public class Bcast<DT> extends CommunicateFunction {
    private final String bufferName;
    private final int root;
    private final TypeInformation<?> type;

    /**
     * Creates the communication step.
     *
     * @param str             name under which the buffer is read on the root task and stored on
     *                        the other tasks
     * @param i               id of the task whose buffer is distributed
     * @param typeInformation type information used for the second field of the intermediate
     *                        {@code (task index, element)} tuples
     */
    public Bcast(String str, int i, TypeInformation<?> typeInformation) {
        this.bufferName = str;
        this.root = i;
        this.type = typeInformation;
    }

    /**
     * Appends the distribution stages to {@code dataSet}.
     *
     * <p>Neither stage forwards the incoming records of {@code dataSet}; the last stage never
     * writes to its collector, so the returned data set carries no records and the transfer
     * happens purely through {@link ComContext#putObj}.
     *
     * @param dataSet data set the stages are chained onto; it is also attached to the first
     *                stage as the broadcast set named {@code "barrier"}
     * @param i       value handed to the {@link ComContext} constructor in both stages
     * @param <T>     record type of {@code dataSet}
     * @return a data set with the type of {@code dataSet}
     */
    @Override
    public <T> DataSet<T> communicateWith(DataSet<T> dataSet, final int i) {
        // Plain local copies, so the anonymous functions below use these values instead of
        // reading the fields of this instance.
        final String name = this.bufferName;
        final int rootTask = this.root;

        return dataSet
            .mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, DT>>() {
                @Override
                public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, DT>> out) {
                    ComContext context = new ComContext(i, getIterationRuntimeContext());
                    if (context.getTaskId() != rootTask) {
                        // Only the root task owns the buffer that is sent out.
                        return;
                    }

                    // The buffer holds DT elements (not input records of type T); the stored
                    // object is untyped, hence the unchecked cast.
                    @SuppressWarnings("unchecked")
                    Iterable<DT> buffer = (Iterable<DT>) context.getObj(name);
                    int numTask = context.getNumTask();

                    for (DT element : buffer) {
                        // One copy per task index; the index is the routing key used by the
                        // custom partitioner of the next stage.
                        for (int target = 0; target < numTask; target++) {
                            out.collect(Tuple2.of(target, element));
                        }
                    }
                }
            })
            .returns(Types.TUPLE(Types.INT, this.type))
            // NOTE(review): the input is attached as a broadcast set but never read here;
            // presumably it only forces this stage to wait for the input - confirm.
            .withBroadcastSet(dataSet, "barrier")
            .partitionCustom(new Partitioner<Integer>() {
                @Override
                public int partition(Integer key, int numPartitions) {
                    // The key already is the index of the destination partition.
                    return key.intValue();
                }
            }, 0)
            .mapPartition(new RichMapPartitionFunction<Tuple2<Integer, DT>, T>() {
                @Override
                public void mapPartition(Iterable<Tuple2<Integer, DT>> values, Collector<T> out) {
                    ComContext context = new ComContext(i, getIterationRuntimeContext());
                    if (context.getTaskId() == rootTask) {
                        // The root keeps its existing buffer and ignores the copies routed to it.
                        return;
                    }

                    ArrayList<DT> received = new ArrayList<>();
                    for (Tuple2<Integer, DT> addressed : values) {
                        received.add(addressed.f1);
                    }
                    context.putObj(name, received);
                }
            })
            .returns(dataSet.getType());
    }
}
