package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.nlp.WordCountUtil;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/SortUtilsNext.class */
public final class SortUtilsNext {
    /** Shared SLF4J logger; the nested functions of this class write their lifecycle messages through it. */
    private static final Logger LOG = LoggerFactory.getLogger(SortUtilsNext.class);
    /**
     * Upper bound used when sizing the per-partition sample of split-point candidates
     * (compare the {@code Math.min} bound in {@code SampleSplitPoint#mapPartition}).
     */
    private static final int SPLIT_POINT_SIZE = 1000;

    /**
     * Picks evenly spaced records of one partition as split-point candidates.
     *
     * <p>Every sampled record is emitted as {@code (record, subtaskIndex)}. In addition each subtask emits one
     * marker tuple {@code (parallelism, -subtaskIndex - 1)}: its second field is always negative, which sets it
     * apart from real samples, and its first field carries the number of parallel subtasks.
     *
     * <p>The number of records in this subtask's partition is looked up in the broadcast set registered under
     * {@link WordCountUtil#COUNT_COL_NAME}, whose entries are read as {@code (subtaskIndex, count)} pairs.
     *
     * @param <T> type of the records being sampled
     */
    public static final class SampleSplitPoint<T> extends RichMapPartitionFunction<T, Tuple2<Object, Integer>> {
        private static final long serialVersionUID = -4623812104023341398L;
        /** Index of the subtask running this function instance. */
        private int taskId;
        /** Record count of this subtask's partition; stays 0 when the broadcast set has no entry for it. */
        private int cnt;

        /**
         * Stores the subtask index and looks up this subtask's record count in the broadcast count set.
         *
         * @param configuration Flink configuration, forwarded to the superclass
         */
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            final RuntimeContext context = getRuntimeContext();
            this.taskId = context.getIndexOfThisSubtask();
            SortUtilsNext.LOG.info("{} open.", context.getTaskName());
            final List<Tuple2<Integer, Long>> partitionCounts =
                context.getBroadcastVariable(WordCountUtil.COUNT_COL_NAME);
            for (Tuple2<Integer, Long> entry : partitionCounts) {
                if (entry.f0.intValue() == this.taskId) {
                    // The count is narrowed to int, as the sampling arithmetic below works on ints.
                    this.cnt = entry.f1.intValue();
                    break;
                }
            }
            SortUtilsNext.LOG.info("{} open end.", context.getTaskName());
        }

        /** Logs that the function is being closed. */
        public void close() throws Exception {
            super.close();
            SortUtilsNext.LOG.info("{} close.", getRuntimeContext().getTaskName());
        }

        /**
         * Emits the sampled records of this partition followed by the marker tuple. When the partition is
         * known to be empty ({@code cnt <= 0}) only the marker tuple is emitted.
         *
         * @param iterable  records of this partition
         * @param collector receives the samples and the marker tuple
         */
        public void mapPartition(Iterable<T> iterable, Collector<Tuple2<Object, Integer>> collector) throws Exception {
            SortUtilsNext.LOG.info("{} mapPartition start.", getRuntimeContext().getTaskName());
            if (this.cnt > 0) {
                emitSamples(iterable, collector);
                emitMarker(collector);
                SortUtilsNext.LOG.info("{} mapPartition end.", getRuntimeContext().getTaskName());
            } else {
                emitMarker(collector);
            }
        }

        /** Walks the partition once and emits the records located at the positions given by genSampleIndex. */
        private void emitSamples(Iterable<T> iterable, Collector<Tuple2<Object, Integer>> collector) {
            final int intervals = Math.min(SortUtilsNext.SPLIT_POINT_SIZE, this.cnt - 1);
            int sampleNo = 0;
            int target = (int) SortUtilsNext.genSampleIndex(sampleNo, this.cnt, intervals);
            int position = 0;
            final Iterator<T> records = iterable.iterator();
            while (records.hasNext()) {
                final T record = records.next();
                if (position == target) {
                    collector.collect(Tuple2.<Object, Integer>of(record, this.taskId));
                    sampleNo++;
                    target = (int) SortUtilsNext.genSampleIndex(sampleNo, this.cnt, intervals);
                    if (target >= this.cnt) {
                        // The next sample position lies past the last record: nothing more to pick.
                        break;
                    }
                }
                position++;
            }
        }

        /** Emits the marker tuple {@code (parallelism, -taskId - 1)}. */
        private void emitMarker(Collector<Tuple2<Object, Integer>> collector) {
            final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
            collector.collect(Tuple2.<Object, Integer>of(parallelism, -this.taskId - 1));
        }
    }

    /**
     * Tags every record of a partition with the index of the split-point bucket it falls into.
     *
     * <p>The split points are read from the broadcast set {@code "splitPoints"} and sorted with
     * {@link SortUtils.PairComparator}. A record is emitted as {@code (bucket, record)} where {@code bucket}
     * is the number of split points that compare lower than {@code (record, subtaskIndex)}; records greater
     * than every split point therefore land in bucket {@code splitPoints.size()}.
     *
     * <p>The bucket index only ever moves forward while the partition is consumed, so the records must
     * arrive in ascending order (in {@code pSort} this function is applied to the locally sorted partitions).
     *
     * @param <T> type of the records being tagged
     */
    public static final class SplitData<T> extends RichMapPartitionFunction<T, Tuple2<Integer, T>> {
        private static final long serialVersionUID = -2108948299398213692L;
        /** Index of the subtask running this function instance. */
        private int taskId;
        /** Broadcast split points, sorted ascending in {@link #open(Configuration)}. */
        private List<Tuple2<Object, Integer>> splitPoints;
        /** Single output tuple that is refilled and re-emitted for every record. */
        private Tuple2<Integer, T> outBuff;

        /** Logs that the function is being closed. */
        public void close() throws Exception {
            super.close();
            SortUtilsNext.LOG.info("{} close.", getRuntimeContext().getTaskName());
        }

        /**
         * Stores the subtask index, materializes the sorted split-point list and allocates the output buffer.
         *
         * @param configuration Flink configuration, forwarded to the superclass
         */
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            RuntimeContext runtimeContext = getRuntimeContext();
            this.taskId = runtimeContext.getIndexOfThisSubtask();
            this.splitPoints = runtimeContext.getBroadcastVariableWithInitializer(
                "splitPoints",
                new BroadcastVariableInitializer<Tuple2<Object, Integer>, List<Tuple2<Object, Integer>>>() {
                    public List<Tuple2<Object, Integer>> initializeBroadcastVariable(
                        Iterable<Tuple2<Object, Integer>> iterable) {
                        List<Tuple2<Object, Integer>> points = new ArrayList<>();
                        for (Tuple2<Object, Integer> point : iterable) {
                            points.add(point);
                        }
                        points.sort(new SortUtils.PairComparator());
                        return points;
                    }
                });
            this.outBuff = new Tuple2<>();
            SortUtilsNext.LOG.info("{} open.", getRuntimeContext().getTaskName());
        }

        /**
         * Emits {@code (bucket, record)} for every record of the partition.
         *
         * @param iterable  records of this partition, expected in ascending order
         * @param collector receives the tagged records
         */
        public void mapPartition(Iterable<T> iterable, Collector<Tuple2<Integer, T>> collector) throws Exception {
            final int size = this.splitPoints.size();
            final SortUtils.PairComparator pairComparator = new SortUtils.PairComparator();
            // Probe tuple compared against the split points; only its first field changes per record.
            final Tuple2<Object, Integer> probe = Tuple2.<Object, Integer>of(null, this.taskId);
            int bucket = 0;
            for (T record : iterable) {
                probe.f0 = record;
                // The bound check must come first: once every split point has been passed the record
                // belongs to the last bucket (index == size) and there is no split point left to compare to.
                // With no split points at all the loop never runs and every record goes to bucket 0.
                while (bucket < size && pairComparator.compare(probe, this.splitPoints.get(bucket)) > 0) {
                    bucket++;
                }
                this.outBuff.setFields(bucket, record);
                collector.collect(this.outBuff);
            }
        }
    }

    /**
     * Reduces the split-point candidates of all subtasks to the final set of split points.
     *
     * <p>Input tuples with a negative second field are markers whose first field is read as the number of
     * parallel instances; all other tuples are candidates. The candidates are sorted with
     * {@link SortUtils.PairComparator}, {@code instances - 1} evenly spaced entries are chosen via
     * {@code genSampleIndex}, and the distinct chosen entries are emitted.
     */
    public static class SplitPointReducer extends RichGroupReduceFunction<Tuple2<Object, Integer>, Tuple2<Object, Integer>> {
        private static final long serialVersionUID = 3507859541333351869L;

        /** Logs that the function has been opened. */
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            SortUtilsNext.LOG.info("{} open.", getRuntimeContext().getTaskName());
        }

        /** Logs that the function is being closed. */
        public void close() throws Exception {
            super.close();
            SortUtilsNext.LOG.info("{} close.", getRuntimeContext().getTaskName());
        }

        /**
         * Selects the split points from the collected candidates.
         *
         * @param iterable  candidate and marker tuples
         * @param collector receives the distinct split points, in the iteration order of a {@link HashSet}
         * @throws Exception if a computed sample index falls outside the candidate list
         */
        public void reduce(Iterable<Tuple2<Object, Integer>> iterable, Collector<Tuple2<Object, Integer>> collector) throws Exception {
            SortUtilsNext.LOG.info("{} reduce start.", getRuntimeContext().getTaskName());
            final List<Tuple2<Object, Integer>> candidates = new ArrayList<>();
            int instances = -1;
            final Iterator<Tuple2<Object, Integer>> incoming = iterable.iterator();
            while (incoming.hasNext()) {
                final Tuple2<Object, Integer> entry = incoming.next();
                if (entry.f1.intValue() >= 0) {
                    // Store a fresh tuple rather than the instance handed out by the iterator.
                    candidates.add(Tuple2.of(entry.f0, entry.f1));
                    continue;
                }
                instances = ((Integer) entry.f0).intValue();
            }
            if (candidates.isEmpty()) {
                return;
            }
            final int total = candidates.size();
            candidates.sort(new SortUtils.PairComparator());
            final int splitCount = instances - 1;
            final HashSet<Tuple2<Object, Integer>> chosen = new HashSet<>();
            int splitNo = 0;
            while (splitNo < splitCount) {
                final int index = (int) SortUtilsNext.genSampleIndex(splitNo, total, splitCount);
                if (index >= total) {
                    throw new Exception("Index error. index: " + index + ". totalCount: " + total);
                }
                chosen.add(candidates.get(index));
                splitNo++;
            }
            for (Tuple2<Object, Integer> splitPoint : chosen) {
                collector.collect(splitPoint);
            }
            SortUtilsNext.LOG.info("{} reduce end.", getRuntimeContext().getTaskName());
        }
    }

    /**
     * Builds the Flink plan that sorts each partition locally and then redistributes the records by range.
     *
     * <p>Stages, in order:
     * <ol>
     *   <li>count the records of every input partition;</li>
     *   <li>sort every partition in memory by natural order ("sorted");</li>
     *   <li>sample split-point candidates per partition and reduce them to the final split points
     *       ("SampleSplitPoint", "SplitPointReducer");</li>
     *   <li>tag every record with its bucket index ("SplitData") and route bucket {@code b} to partition
     *       {@code b % numPartitions} ("partitionCustom");</li>
     *   <li>strip the bucket index again ("partitioned").</li>
     * </ol>
     * No further sort is applied after the repartitioning step in this method.
     *
     * @param dataSet records to sort
     * @param <T>     record type; must be comparable to itself
     * @return the repartitioned data set together with the per-partition record counts of that data set
     */
    public static <T extends Comparable<T>> Tuple2<DataSet<T>, DataSet<Tuple2<Integer, Long>>> pSort(DataSet<T> dataSet) {
        DataSet<Tuple2<Integer, Long>> inputCounts = DataSetUtils.countElementsPerPartition(dataSet);

        // Local in-memory sort of every partition.
        DataSet<T> sorted = dataSet
            .mapPartition(new RichMapPartitionFunction<T, T>() {
                private static final long serialVersionUID = -6035779320647427287L;
                int taskId;
                int cnt;

                public void open(Configuration configuration) throws Exception {
                    super.open(configuration);
                    this.taskId = getRuntimeContext().getIndexOfThisSubtask();
                    SortUtilsNext.LOG.info("{} open.", getRuntimeContext().getTaskName());
                    List<Tuple2<Integer, Long>> counts =
                        getRuntimeContext().getBroadcastVariable(WordCountUtil.COUNT_COL_NAME);
                    for (Tuple2<Integer, Long> count : counts) {
                        if (count.f0.intValue() == this.taskId) {
                            // Only used to pre-size the sort buffer below.
                            this.cnt = count.f1.intValue();
                            return;
                        }
                    }
                }

                public void close() throws Exception {
                    super.close();
                    SortUtilsNext.LOG.info("{} close.", getRuntimeContext().getTaskName());
                }

                public void mapPartition(Iterable<T> iterable, Collector<T> collector) throws Exception {
                    SortUtilsNext.LOG.info("{} map partition start.", getRuntimeContext().getTaskName());
                    SortUtilsNext.LOG.info("{} map partition start.", Integer.valueOf(this.cnt));
                    List<T> buffer = new ArrayList<>(this.cnt);
                    SortUtilsNext.LOG.info("{} map sort all.", getRuntimeContext().getTaskName());
                    for (T record : iterable) {
                        buffer.add(record);
                    }
                    SortUtilsNext.LOG.info("{} map sort start.", getRuntimeContext().getTaskName());
                    buffer.sort(Comparator.naturalOrder());
                    SortUtilsNext.LOG.info("{} map sort end.", getRuntimeContext().getTaskName());
                    for (T record : buffer) {
                        collector.collect(record);
                    }
                    SortUtilsNext.LOG.info("{} map partition end.", getRuntimeContext().getTaskName());
                }
            })
            .name("sorted")
            .withBroadcastSet(inputCounts, WordCountUtil.COUNT_COL_NAME)
            .returns(dataSet.getType());

        // Split points derived from samples of the sorted partitions.
        DataSet<Tuple2<Object, Integer>> splitPoints = sorted
            .mapPartition(new SampleSplitPoint<T>())
            .name("SampleSplitPoint")
            .withBroadcastSet(inputCounts, WordCountUtil.COUNT_COL_NAME)
            .reduceGroup(new SplitPointReducer())
            .name("SplitPointReducer");

        // Tag each record with its bucket, shuffle by bucket, then drop the tag.
        DataSet<T> partitioned = sorted
            .mapPartition(new SplitData<T>())
            .name("SplitData")
            .withBroadcastSet(splitPoints, "splitPoints")
            .returns(new TupleTypeInfo<Tuple2<Integer, T>>(Types.INT, dataSet.getType()))
            .partitionCustom(new Partitioner<Integer>() {
                private static final long serialVersionUID = 8886451959375466262L;

                public int partition(Integer key, int numPartitions) {
                    return key.intValue() % numPartitions;
                }
            }, 0)
            .name("partitionCustom")
            .map(new MapFunction<Tuple2<Integer, T>, T>() {
                private static final long serialVersionUID = -7579866556587630212L;

                public T map(Tuple2<Integer, T> tagged) throws Exception {
                    return tagged.f1;
                }
            })
            .name("partitioned")
            .returns(dataSet.getType());

        DataSet<Tuple2<Integer, Long>> partitionedCounts = DataSetUtils.countElementsPerPartition(partitioned);
        return Tuple2.of(partitioned, partitionedCounts);
    }

    /**
     * Computes the zero-based position of one evenly spaced sample.
     *
     * <p>The {@code j2} elements are divided into {@code j3 + 1} consecutive chunks whose sizes differ by at
     * most one (the first {@code j2 % (j3 + 1)} chunks get the extra element). The result is the index of
     * the last element of chunk number {@code j} (zero-based), i.e. the combined size of chunks
     * {@code 0..j} minus one.
     *
     * @param j  zero-based number of the requested sample
     * @param j2 total number of elements
     * @param j3 number of split points; must not be {@code -1}, which would divide by zero
     * @return the sample position; {@code >= j2} once {@code j} exceeds {@code j3}, and {@code -1} for
     *         {@code j2 == 0}
     */
    public static long genSampleIndex(long j, long j2, long j3) {
        final long chunksTaken = j + 1;
        final long chunkCount = j3 + 1;
        final long baseSize = j2 / chunkCount;
        final long oversized = j2 % chunkCount;
        // Each of the first `oversized` chunks holds one extra element.
        final long extras = Math.min(oversized, chunksTaken);
        return baseSize * chunksTaken + extras - 1;
    }
}
