package com.alibaba.alink.operator.common.tree.parallelcart.communication;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CommunicateFunction;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.io.directreader.DistributedInfo;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.parallelcart.fakeserializer.FakeDoublePrimitiveArrayType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/communication/ReduceScatter.class */
public class ReduceScatter extends CommunicateFunction {
    public static final int TRANSFER_BUFFER_SIZE = 4094;
    private static final Logger LOG = LoggerFactory.getLogger(ReduceScatter.class);
    private static final long serialVersionUID = -5868324335725015627L;
    private final String bufferName;
    private final String recvbufName;
    private final String recvcntsName;
    private final AllReduce.SerializableBiConsumer<double[], double[]> op;

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/communication/ReduceScatter$ReduceSend.class */
    private static class ReduceSend<T> extends RichMapPartitionFunction<T, Tuple3<Integer, Integer, double[]>> {
        private static final long serialVersionUID = -1122497764898022914L;
        private final String bufferName;
        private final String transferBufferName;
        private final String recvcntsName;
        private final int sessionId;
        private Tuple3<Integer, Integer, double[]> outputBuf;
        private DistributedInfo distributedInfo;

        ReduceSend(String str, String str2, String str3, int i) {
            this.bufferName = str;
            this.transferBufferName = str3;
            this.recvcntsName = str2;
            this.sessionId = i;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.outputBuf = new Tuple3<>();
            this.distributedInfo = new DefaultDistributedInfo();
        }

        public void mapPartition(Iterable<T> iterable, Collector<Tuple3<Integer, Integer, double[]>> collector) throws Exception {
            double[] dArr;
            if (getRuntimeContext().getNumberOfParallelSubtasks() == 1) {
                return;
            }
            ComContext comContext = new ComContext(this.sessionId, getIterationRuntimeContext());
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            int superstepNumber = getIterationRuntimeContext().getSuperstepNumber();
            ReduceScatter.LOG.info("taskId: {}, ReduceSend start", Integer.valueOf(indexOfThisSubtask));
            double[] dArr2 = (double[]) comContext.getObj(this.bufferName);
            int[] iArr = (int[]) comContext.getObj(this.recvcntsName);
            int i = 0;
            for (int i2 : iArr) {
                i += ReduceScatter.pieces(i2);
            }
            if (superstepNumber == 1) {
                dArr = new double[ReduceScatter.TRANSFER_BUFFER_SIZE];
                comContext.putObj(this.transferBufferName, dArr);
            } else {
                dArr = (double[]) comContext.getObj(this.transferBufferName);
            }
            int i3 = 0;
            int i4 = 0;
            for (int i5 = 0; i5 < iArr.length; i5++) {
                int pieces = ReduceScatter.pieces(iArr[i5]);
                for (int i6 = 0; i6 < pieces; i6++) {
                    if (i6 == pieces - 1) {
                        int lastLen = ReduceScatter.lastLen(iArr[i5]);
                        System.arraycopy(dArr2, i4, dArr, 0, lastLen);
                        i4 += lastLen;
                    } else {
                        System.arraycopy(dArr2, i4, dArr, 0, ReduceScatter.TRANSFER_BUFFER_SIZE);
                        i4 += ReduceScatter.TRANSFER_BUFFER_SIZE;
                    }
                    this.outputBuf.setFields(Integer.valueOf((int) this.distributedInfo.where(i3, comContext.getNumTask(), i)), Integer.valueOf(i3), dArr);
                    collector.collect(this.outputBuf);
                    i3++;
                }
            }
            ReduceScatter.LOG.info("taskId: {}, ReduceSend end", Integer.valueOf(indexOfThisSubtask));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/communication/ReduceScatter$ReduceSum.class */
    private static class ReduceSum extends RichMapPartitionFunction<Tuple3<Integer, Integer, double[]>, Tuple3<Integer, Integer, double[]>> {
        private static final long serialVersionUID = 8374168645659944661L;
        private final String recvcntsName;
        private final int sessionId;
        private final AllReduce.SerializableBiConsumer<double[], double[]> op;
        private int[] offset;
        ArrayList<double[]> sum;
        Tuple3<Integer, Integer, double[]> outBuf;
        DistributedInfo distributedInfo = new DefaultDistributedInfo();

        ReduceSum(String str, int i, AllReduce.SerializableBiConsumer<double[], double[]> serializableBiConsumer) {
            this.sessionId = i;
            this.recvcntsName = str;
            this.op = serializableBiConsumer;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
        }

        /* JADX WARN: Multi-variable type inference failed */
        public void mapPartition(Iterable<Tuple3<Integer, Integer, double[]>> iterable, Collector<Tuple3<Integer, Integer, double[]>> collector) {
            if (getRuntimeContext().getNumberOfParallelSubtasks() == 1) {
                return;
            }
            ComContext comContext = new ComContext(this.sessionId, getIterationRuntimeContext());
            if (comContext.getStepNo() == 1) {
                this.offset = new int[getRuntimeContext().getNumberOfParallelSubtasks()];
                this.sum = new ArrayList<>();
                this.outBuf = new Tuple3<>();
            }
            Iterator<Tuple3<Integer, Integer, double[]>> it = iterable.iterator();
            if (it.hasNext()) {
                int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
                int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
                ReduceScatter.LOG.info("taskId: {}, ReduceSum start", Integer.valueOf(indexOfThisSubtask));
                int[] iArr = (int[]) comContext.getObj(this.recvcntsName);
                int i = 0;
                for (int i2 = 0; i2 < iArr.length; i2++) {
                    this.offset[i2] = i;
                    i += ReduceScatter.pieces(iArr[i2]);
                }
                int startPos = (int) this.distributedInfo.startPos(indexOfThisSubtask, numberOfParallelSubtasks, i);
                int localRowCnt = (int) this.distributedInfo.localRowCnt(indexOfThisSubtask, numberOfParallelSubtasks, i);
                ReduceScatter.LOG.info("taskId: {}, ReduceSum cnt: {}", Integer.valueOf(indexOfThisSubtask), Integer.valueOf(localRowCnt));
                int size = this.sum.size();
                if (size < localRowCnt) {
                    this.sum.ensureCapacity(localRowCnt * 2);
                    for (int i3 = size; i3 < localRowCnt * 2; i3++) {
                        this.sum.add(new double[ReduceScatter.TRANSFER_BUFFER_SIZE]);
                    }
                }
                for (int i4 = 0; i4 < localRowCnt; i4++) {
                    Arrays.fill(this.sum.get(i4), Criteria.INVALID_GAIN);
                }
                long j = 0;
                do {
                    Tuple3<Integer, Integer, double[]> next = it.next();
                    long currentTimeMillis = System.currentTimeMillis();
                    this.op.accept(this.sum.get(((Integer) next.f1).intValue() - startPos), next.f2);
                    j += System.currentTimeMillis() - currentTimeMillis;
                } while (it.hasNext());
                ReduceScatter.LOG.info("taskId: {}, ReduceSum time: {}", Integer.valueOf(indexOfThisSubtask), Long.valueOf(j));
                long j2 = 0;
                for (int i5 = 0; i5 < localRowCnt; i5++) {
                    long currentTimeMillis2 = System.currentTimeMillis();
                    int binarySearch = Arrays.binarySearch(this.offset, startPos + i5);
                    int i6 = binarySearch >= 0 ? binarySearch : (-binarySearch) - 2;
                    while (i6 < numberOfParallelSubtasks && iArr[i6] == 0) {
                        i6++;
                    }
                    if (iArr[i6] == 0) {
                        throw new IllegalStateException("It should not be empty in the scatter task. Maybe it is an issue.");
                    }
                    j2 += System.currentTimeMillis() - currentTimeMillis2;
                    this.outBuf.setFields(Integer.valueOf(i6), Integer.valueOf((startPos + i5) - this.offset[i6]), this.sum.get(i5));
                    collector.collect(this.outBuf);
                }
                ReduceScatter.LOG.info("taskId: {}, ScatterSend time: {}", Integer.valueOf(indexOfThisSubtask), Long.valueOf(j2));
                ReduceScatter.LOG.info("taskId: {}, ReduceSum end", Integer.valueOf(indexOfThisSubtask));
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/communication/ReduceScatter$ScatterRecv.class */
    private static class ScatterRecv<T> extends RichMapPartitionFunction<Tuple3<Integer, Integer, double[]>, T> {
        private static final long serialVersionUID = 5441679946742930637L;
        private final String bufferName;
        private final String recvcntsName;
        private final int sessionId;

        ScatterRecv(String str, String str2, int i) {
            this.bufferName = str;
            this.recvcntsName = str2;
            this.sessionId = i;
        }

        public void mapPartition(Iterable<Tuple3<Integer, Integer, double[]>> iterable, Collector<T> collector) throws Exception {
            if (getRuntimeContext().getNumberOfParallelSubtasks() == 1) {
                return;
            }
            ComContext comContext = new ComContext(this.sessionId, getIterationRuntimeContext());
            Iterator<Tuple3<Integer, Integer, double[]>> it = iterable.iterator();
            if (it.hasNext()) {
                int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
                ReduceScatter.LOG.info("taskId: {}, ScatterRecv start", Integer.valueOf(indexOfThisSubtask));
                double[] dArr = (double[]) comContext.getObj(this.bufferName);
                int i = ((int[]) comContext.getObj(this.recvcntsName))[comContext.getTaskId()];
                int pieces = ReduceScatter.pieces(i);
                do {
                    Tuple3<Integer, Integer, double[]> next = it.next();
                    if (((Integer) next.f1).intValue() == pieces - 1) {
                        System.arraycopy(next.f2, 0, dArr, ((Integer) next.f1).intValue() * ReduceScatter.TRANSFER_BUFFER_SIZE, ReduceScatter.lastLen(i));
                    } else {
                        System.arraycopy(next.f2, 0, dArr, ((Integer) next.f1).intValue() * ReduceScatter.TRANSFER_BUFFER_SIZE, ReduceScatter.TRANSFER_BUFFER_SIZE);
                    }
                } while (it.hasNext());
                ReduceScatter.LOG.info("taskId: {}, ScatterRecv end", Integer.valueOf(indexOfThisSubtask));
            }
        }
    }

    public ReduceScatter(String str, String str2, String str3, AllReduce.SerializableBiConsumer<double[], double[]> serializableBiConsumer) {
        this.bufferName = str;
        this.op = serializableBiConsumer;
        this.recvbufName = str2;
        this.recvcntsName = str3;
    }

    @Override // com.alibaba.alink.common.comqueue.CommunicateFunction
    public <T> DataSet<T> communicateWith(DataSet<T> dataSet, int i) {
        return dataSet.mapPartition(new ReduceSend(this.bufferName, this.recvcntsName, UUID.randomUUID().toString(), i)).withBroadcastSet(dataSet, "barrier").returns(new TupleTypeInfo(new TypeInformation[]{Types.INT, Types.INT, FakeDoublePrimitiveArrayType.FAKE_DOUBLE_PRIMITIVE_ARRAY_TYPE})).name("ReduceSend").partitionCustom(new Partitioner<Integer>() { // from class: com.alibaba.alink.operator.common.tree.parallelcart.communication.ReduceScatter.2
            private static final long serialVersionUID = 3268524067347394279L;

            public int partition(Integer num, int i2) {
                return num.intValue();
            }
        }, 0).name("ReduceBroadcastRaw").mapPartition(new ReduceSum(this.recvcntsName, i, this.op)).returns(new TupleTypeInfo(new TypeInformation[]{Types.INT, Types.INT, FakeDoublePrimitiveArrayType.FAKE_DOUBLE_PRIMITIVE_ARRAY_TYPE})).name("ReduceSum").partitionCustom(new Partitioner<Integer>() { // from class: com.alibaba.alink.operator.common.tree.parallelcart.communication.ReduceScatter.1
            private static final long serialVersionUID = 4153050313317398157L;

            public int partition(Integer num, int i2) {
                return num.intValue();
            }
        }, 0).name("ReduceBroadcastSum").mapPartition(new ScatterRecv(this.recvbufName, this.recvcntsName, i)).returns(dataSet.getType()).name("ScatterRecv");
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static int pieces(int i) {
        int i2 = i / TRANSFER_BUFFER_SIZE;
        return i % TRANSFER_BUFFER_SIZE == 0 ? i2 : i2 + 1;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static int lastLen(int i) {
        int i2 = i % TRANSFER_BUFFER_SIZE;
        return i2 == 0 ? TRANSFER_BUFFER_SIZE : i2;
    }
}
