package com.alibaba.alink.operator.common.tree.parallelcart.communication;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CommunicateFunction;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.linalg.VectorUtil;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Objects;
import java.util.UUID;
import java.util.function.BiConsumer;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/communication/AllReduceT.class */
/**
 * All-reduce {@link CommunicateFunction} over an object array kept in the {@link ComContext}.
 *
 * <p>The buffer stored under {@code bufferName} is treated as {@code length} elements. {@code length}
 * is read from {@code lengthName} when that name is non-null; otherwise it is the buffer's array
 * length. The buffer is cut into consecutive pieces of {@link #TRANSFER_BUFFER_SIZE} elements; the
 * last piece may be shorter. The exchange runs in three stages:
 *
 * <ol>
 *   <li>{@link AllReduceSend}: every subtask sends each piece of its buffer to the subtask that owns
 *       that piece. Ownership is assigned by {@link DefaultDistributedInfo}.</li>
 *   <li>{@link AllReduceSum}: each owner folds all received copies of a piece with {@code op} and
 *       sends the reduced piece to every subtask.</li>
 *   <li>{@link AllReduceRecv}: every subtask copies the reduced pieces back into its own buffer.</li>
 * </ol>
 *
 * <p>Records are {@code (targetSubtask, pieceIndex, pieceData)} tuples, routed by field 0.
 *
 * <p>{@code op.accept(acc, other)} must merge {@code other} into {@code acc} in place. The reducer
 * keeps the first array it received for a piece and re-sends that same array after folding.
 *
 * <p>With a parallelism of 1, every stage returns immediately and the buffer is left untouched.
 *
 * @param <B> element type of the reduced buffer
 */
public class AllReduceT<B> extends CommunicateFunction {
    /** Number of buffer elements carried by one transfer record (one "piece"). */
    private static final int TRANSFER_BUFFER_SIZE = 1;
    private static final Logger LOG = LoggerFactory.getLogger(AllReduceT.class);
    private static final long serialVersionUID = -1081251494034109951L;
    private final String bufferName;
    private final String lengthName;
    private final SerializableBiConsumer<B[], B[]> op;
    private final Class<B> elementClass;

    /**
     * Final stage: writes every reduced piece it receives into the local buffer in the
     * {@link ComContext}.
     */
    public static class AllReduceRecv<T, B> extends RichMapPartitionFunction<Tuple3<Integer, Integer, B[]>, T> {
        private static final long serialVersionUID = 2526497225558653932L;
        private final String bufferName;
        private final String lengthName;
        private final int sessionId;

        AllReduceRecv(String bufferName, String lengthName, int sessionId) {
            this.bufferName = bufferName;
            this.lengthName = lengthName;
            this.sessionId = sessionId;
        }

        @Override
        public void mapPartition(Iterable<Tuple3<Integer, Integer, B[]>> values, Collector<T> out) throws Exception {
            if (getRuntimeContext().getNumberOfParallelSubtasks() == 1) {
                return;
            }
            ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
            Iterator<Tuple3<Integer, Integer, B[]>> it = values.iterator();
            if (!it.hasNext()) {
                return;
            }
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            AllReduceT.LOG.info("taskId: {}, AllReduceRecv start", taskId);

            Object[] buffer = (Object[]) context.getObj(bufferName);
            int length = AllReduceT.bufferLength(context, bufferName, lengthName);
            int pieces = AllReduceT.pieces(length);

            while (it.hasNext()) {
                Tuple3<Integer, Integer, B[]> piece = it.next();
                int pieceIndex = piece.f1;
                System.arraycopy(piece.f2, 0, buffer, pieceIndex * TRANSFER_BUFFER_SIZE,
                    AllReduceT.pieceLength(pieceIndex, pieces, length));
            }
            AllReduceT.LOG.info("taskId: {}, AllReduceRecv end", taskId);
        }
    }

    /**
     * First stage: splits the local buffer into pieces and sends each piece to the subtask that
     * owns it.
     */
    public static class AllReduceSend<T, B> extends RichMapPartitionFunction<T, Tuple3<Integer, Integer, B[]>> {
        private static final long serialVersionUID = 7772757421079527212L;
        private final String bufferName;
        private final String lengthName;
        private final String transferBufferName;
        private final int sessionId;
        private final Class<B> elementClass;

        AllReduceSend(String bufferName, String lengthName, String transferBufferName, int sessionId,
                      Class<B> elementClass) {
            this.bufferName = bufferName;
            this.lengthName = lengthName;
            this.transferBufferName = transferBufferName;
            this.sessionId = sessionId;
            this.elementClass = elementClass;
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple3<Integer, Integer, B[]>> out) throws Exception {
            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            if (numTasks == 1) {
                return;
            }
            ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            int superstep = getIterationRuntimeContext().getSuperstepNumber();
            AllReduceT.LOG.info("taskId: {}, AllReduceSend start", taskId);

            Object[] buffer = (Object[]) context.getObj(bufferName);
            int length = AllReduceT.bufferLength(context, bufferName, lengthName);
            B[] transfer = transferBuffer(context, superstep);
            int pieces = AllReduceT.pieces(length);
            AllReduceT.LOG.info("taskId: {}, len: {}, pieces: {}", taskId, length, pieces);

            DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
            for (int target = 0; target < numTasks; target++) {
                int startPos = (int) distributedInfo.startPos(target, numTasks, pieces);
                int ownedCount = (int) distributedInfo.localRowCnt(target, numTasks, pieces);
                for (int pieceIndex = startPos; pieceIndex < startPos + ownedCount; pieceIndex++) {
                    System.arraycopy(buffer, pieceIndex * TRANSFER_BUFFER_SIZE, transfer, 0,
                        AllReduceT.pieceLength(pieceIndex, pieces, length));
                    // The same transfer array is emitted for every piece. This assumes the record is
                    // serialized on collect (the next step is a custom-partitioned exchange).
                    // NOTE(review): confirm if the topology ever changes.
                    out.collect(Tuple3.of(target, pieceIndex, transfer));
                }
            }
            AllReduceT.LOG.info("taskId: {}, AllReduceSend end", taskId);
        }

        /**
         * Returns the reusable transfer array of {@link #TRANSFER_BUFFER_SIZE} elements. The array is
         * created on superstep 1 and stored in the context; later supersteps read it back.
         */
        @SuppressWarnings("unchecked")
        private B[] transferBuffer(ComContext context, int superstep) {
            if (superstep == 1) {
                B[] created = (B[]) Array.newInstance(elementClass, TRANSFER_BUFFER_SIZE);
                context.putObj(transferBufferName, created);
                return created;
            }
            return (B[]) context.getObj(transferBufferName);
        }
    }

    /**
     * Middle stage: reduces every copy of the pieces owned by this subtask with {@code op} and sends
     * each reduced piece to all subtasks.
     */
    public static class AllReduceSum<B> extends RichMapPartitionFunction<Tuple3<Integer, Integer, B[]>, Tuple3<Integer, Integer, B[]>> {
        private static final long serialVersionUID = -1513792018030661902L;
        private final String bufferName;
        private final String lengthName;
        private final int sessionId;
        private final SerializableBiConsumer<B[], B[]> op;

        AllReduceSum(String bufferName, String lengthName, int sessionId, SerializableBiConsumer<B[], B[]> op) {
            this.bufferName = bufferName;
            this.lengthName = lengthName;
            this.sessionId = sessionId;
            this.op = op;
        }

        @Override
        public void mapPartition(Iterable<Tuple3<Integer, Integer, B[]>> values,
                                 Collector<Tuple3<Integer, Integer, B[]>> out) {
            if (getRuntimeContext().getNumberOfParallelSubtasks() == 1) {
                return;
            }
            ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
            Iterator<Tuple3<Integer, Integer, B[]>> it = values.iterator();
            if (!it.hasNext()) {
                return;
            }
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            AllReduceT.LOG.info("taskId: {}, AllReduceSum start", taskId);

            int pieces = AllReduceT.pieces(AllReduceT.bufferLength(context, bufferName, lengthName));
            DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
            int startPos = (int) distributedInfo.startPos(taskId, numTasks, pieces);
            int ownedCount = (int) distributedInfo.localRowCnt(taskId, numTasks, pieces);

            // One slot per owned piece; the first copy received becomes that piece's accumulator.
            ArrayList<B[]> reduced = new ArrayList<>(ownedCount);
            for (int k = 0; k < ownedCount; k++) {
                reduced.add(null);
            }
            while (it.hasNext()) {
                Tuple3<Integer, Integer, B[]> piece = it.next();
                int slot = piece.f1 - startPos;
                B[] acc = reduced.get(slot);
                if (acc == null) {
                    reduced.set(slot, piece.f2);
                } else {
                    op.accept(acc, piece.f2);
                }
            }

            for (int target = 0; target < numTasks; target++) {
                for (int k = 0; k < ownedCount; k++) {
                    out.collect(Tuple3.of(target, startPos + k, reduced.get(k)));
                }
            }
            AllReduceT.LOG.info("taskId: {}, AllReduceSum end", taskId);
        }
    }

    /**
     * A {@link BiConsumer} that is also {@link Serializable}, so it can be shipped inside Flink
     * functions.
     */
    @FunctionalInterface
    public interface SerializableBiConsumer<T, U> extends BiConsumer<T, U>, Serializable {
        /**
         * Serializable counterpart of {@link BiConsumer#andThen}: runs this consumer, then
         * {@code after}, on the same arguments.
         *
         * @throws NullPointerException if {@code after} is null
         */
        default SerializableBiConsumer<T, U> andThen(SerializableBiConsumer<? super T, ? super U> after) {
            Objects.requireNonNull(after);
            return (first, second) -> {
                accept(first, second);
                after.accept(first, second);
            };
        }
    }

    /** Routes each record to the subtask whose index equals the record's key. */
    private static final class SubtaskPartitioner implements Partitioner<Integer> {
        private static final long serialVersionUID = -4605373044396957398L;

        @Override
        public int partition(Integer key, int numPartitions) {
            return key;
        }
    }

    /**
     * @param bufferName   ComContext name of the object array to reduce in place
     * @param lengthName   ComContext name of an {@code Integer} holding the number of elements to reduce,
     *                     or {@code null} to reduce the whole array
     * @param op           in-place merge: {@code op.accept(acc, other)} folds {@code other} into {@code acc}
     * @param elementClass runtime class of the buffer elements
     */
    public AllReduceT(String bufferName, String lengthName, SerializableBiConsumer<B[], B[]> op, Class<B> elementClass) {
        this.bufferName = bufferName;
        this.lengthName = lengthName;
        this.op = op;
        this.elementClass = elementClass;
    }

    /**
     * Builds the send, reduce and receive pipeline on top of {@code input}.
     *
     * <p>{@code input} is also attached to the send stage as a broadcast set named "barrier".
     * Presumably this delays sending until {@code input} is complete. NOTE(review): confirm intent.
     *
     * @return a data set of {@code input}'s type; the receive stage emits no records of its own
     */
    public static <T, B> DataSet<T> allReduce(DataSet<T> input, String bufferName, String lengthName,
                                              SerializableBiConsumer<B[], B[]> op, int sessionId,
                                              Class<B> elementClass) {
        String transferBufferName = UUID.randomUUID().toString();
        TypeInformation<B[]> arrayType = Types.OBJECT_ARRAY(Types.GENERIC(elementClass));
        TupleTypeInfo<Tuple3<Integer, Integer, B[]>> pieceType =
            new TupleTypeInfo<>(Types.INT, Types.INT, arrayType);

        return input
            .mapPartition(new AllReduceSend<T, B>(bufferName, lengthName, transferBufferName, sessionId, elementClass))
            .withBroadcastSet(input, "barrier")
            .returns(pieceType)
            .name("AllReduceSend")
            .partitionCustom(new SubtaskPartitioner(), 0)
            .name("AllReduceBroadcastRaw")
            .mapPartition(new AllReduceSum<B>(bufferName, lengthName, sessionId, op))
            .returns(pieceType)
            .name("AllReduceSum")
            .partitionCustom(new SubtaskPartitioner(), 0)
            .name("AllReduceBroadcastSum")
            .mapPartition(new AllReduceRecv<T, B>(bufferName, lengthName, sessionId))
            .returns(input.getType())
            .name("AllReduceRecv");
    }

    /** Number of {@link #TRANSFER_BUFFER_SIZE}-element pieces needed to cover {@code length} elements. */
    public static int pieces(int length) {
        int full = length / TRANSFER_BUFFER_SIZE;
        return length % TRANSFER_BUFFER_SIZE == 0 ? full : full + 1;
    }

    /** Number of elements in the last piece when splitting {@code length} elements. */
    public static int lastLen(int length) {
        int remainder = length % TRANSFER_BUFFER_SIZE;
        return remainder == 0 ? TRANSFER_BUFFER_SIZE : remainder;
    }

    /** Number of elements carried by piece {@code pieceIndex} out of {@code pieces}. */
    private static int pieceLength(int pieceIndex, int pieces, int length) {
        return pieceIndex == pieces - 1 ? lastLen(length) : TRANSFER_BUFFER_SIZE;
    }

    /**
     * Number of buffer elements to reduce. This is the {@code Integer} stored under
     * {@code lengthName} when that name is non-null, otherwise the length of the array stored under
     * {@code bufferName}.
     */
    private static int bufferLength(ComContext context, String bufferName, String lengthName) {
        if (lengthName != null) {
            return (Integer) context.getObj(lengthName);
        }
        return ((Object[]) context.getObj(bufferName)).length;
    }

    @Override
    public <T> DataSet<T> communicateWith(DataSet<T> input, int sessionId) {
        return allReduce(input, bufferName, lengthName, op, sessionId, elementClass);
    }
}
