package com.alibaba.alink.operator.common.dataproc;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utilities for range-partitioning a {@link DataSet} of {@link Row}s by the value of one column.
 *
 * <p>{@link #pSort(DataSet, int)} wires the pieces together:
 * <ol>
 *   <li>{@link SampleSplitPoint}: every subtask sorts the keys of its own partition and emits up to
 *       {@link #SPLIT_POINT_SIZE} evenly spaced samples tagged with its subtask index, followed by a
 *       single marker record that carries the number of parallel subtasks.</li>
 *   <li>{@link SplitPointReducer}: merges the samples of all subtasks, sorts them and picks
 *       {@code parallelism - 1} evenly spaced, de-duplicated split points.</li>
 *   <li>{@link SplitData}: tags every row with a partition id obtained by binary-searching its
 *       {@code (key, subtaskIndex)} pair among the (broadcast) split points.</li>
 * </ol>
 * Ties on the key are broken by subtask index, which spreads duplicate keys over neighbouring
 * partitions instead of piling them into one.
 */
public class SortUtils {
    /** Shared stateless comparator: natural ordering of {@link Comparable} values, nulls last. */
    public static final ComparableComparator OBJECT_COMPARATOR = new ComparableComparator();
    private static final Logger LOG = LoggerFactory.getLogger(SortUtils.class);
    /** Maximum number of samples each subtask contributes to split-point selection. */
    public static final int SPLIT_POINT_SIZE = 1000;

    /** Routes a {@code Long} key to partition {@code floorMod(key, numPartitions)}. */
    public static class AvgLongPartitioner implements Partitioner<Long> {
        private static final long serialVersionUID = -4797639155425333832L;

        @Override
        public int partition(Long key, int numPartitions) {
            // floorMod keeps the result in [0, numPartitions) even for negative keys,
            // which a plain '%' would map to an invalid negative partition.
            return (int) Math.floorMod(key.longValue(), (long) numPartitions);
        }
    }

    /** Routes an {@code Integer} key to partition {@code floorMod(key, numPartitions)}. */
    public static class AvgPartition implements Partitioner<Integer> {
        private static final long serialVersionUID = 7926524547138192316L;

        @Override
        public int partition(Integer key, int numPartitions) {
            // floorMod keeps the result in [0, numPartitions) even for negative keys.
            return Math.floorMod(key.intValue(), numPartitions);
        }
    }

    /**
     * Orders arbitrary objects by their natural ordering, placing {@code null} after every
     * non-null value. Non-null arguments must be mutually {@link Comparable}; otherwise a
     * {@link ClassCastException} is thrown.
     */
    public static class ComparableComparator implements Comparator<Object> {
        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        public int compare(Object left, Object right) {
            if (left == null) {
                return right == null ? 0 : 1;
            }
            if (right == null) {
                return -1;
            }
            return ((Comparable) left).compareTo(right);
        }
    }

    /**
     * Orders {@code (value, subtaskIndex)} pairs by value (via {@link ComparableComparator}) and
     * then by subtask index.
     */
    public static class PairComparator implements Comparator<Tuple2<Object, Integer>> {
        ComparableComparator objectComparator = new ComparableComparator();

        @Override
        public int compare(Tuple2<Object, Integer> left, Tuple2<Object, Integer> right) {
            int byValue = objectComparator.compare(left.f0, right.f0);
            return byValue != 0 ? byValue : Integer.compare(left.f1, right.f1);
        }
    }

    /** Compares rows by the field at a fixed position using {@link ComparableComparator}. */
    public static class RowComparator implements Comparator<Row> {
        private final ComparableComparator objectComparator = new ComparableComparator();
        private final int index;

        /** @param i position of the field to compare on */
        public RowComparator(int i) {
            this.index = i;
        }

        @Override
        public int compare(Row left, Row right) {
            return objectComparator.compare(left.getField(index), right.getField(index));
        }
    }

    /**
     * Per-subtask sampler. Emits up to {@link #SPLIT_POINT_SIZE} evenly spaced keys of the local
     * partition as {@code (key, subtaskIndex)}, then one marker record
     * {@code (numberOfParallelSubtasks, -subtaskIndex - 1)}. The marker's second field is always
     * negative, which is how {@link SplitPointReducer} tells it apart from samples. An empty
     * partition emits nothing, not even the marker.
     */
    public static class SampleSplitPoint extends RichMapPartitionFunction<Row, Tuple2<Object, Integer>> {
        private static final long serialVersionUID = 5794681315388872420L;
        private int taskId;
        private final int index;

        /** @param i position of the key field in each row */
        public SampleSplitPoint(int i) {
            this.index = i;
        }

        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.taskId = getRuntimeContext().getIndexOfThisSubtask();
            SortUtils.LOG.info("{} open.", getRuntimeContext().getTaskName());
        }

        @Override
        public void close() throws Exception {
            super.close();
            SortUtils.LOG.info("{} close.", getRuntimeContext().getTaskName());
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Tuple2<Object, Integer>> out) throws Exception {
            List<Object> keys = new ArrayList<>();
            for (Row row : rows) {
                keys.add(row.getField(index));
            }
            if (keys.isEmpty()) {
                return;
            }
            keys.sort(OBJECT_COMPARATOR);

            int total = keys.size();
            // At most total - 1 samples, so a single-row partition contributes no sample.
            int numSamples = Math.min(SPLIT_POINT_SIZE, total - 1);
            List<Object> samples = new ArrayList<>(numSamples);
            for (int s = 0; s < numSamples; s++) {
                int pos = genSampleIndex((long) s, (long) total, (long) numSamples).intValue();
                if (pos >= total) {
                    throw new IllegalStateException("Index error. index: " + pos + ". totalCount: " + total);
                }
                samples.add(keys.get(pos));
            }
            for (Object sample : samples) {
                out.collect(Tuple2.of(sample, taskId));
            }
            // Marker record: parallelism in f0, negative (never a valid subtask index) in f1.
            Integer parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
            out.collect(Tuple2.<Object, Integer>of(parallelism, -taskId - 1));
        }
    }

    /**
     * Tags every row with the id of the range partition it belongs to. The partition id is the
     * position of the row's {@code (key, subtaskIndex)} pair among the sorted split points read
     * from the {@code "splitPoints"} broadcast set: the match index on an exact hit, otherwise the
     * insertion point. When there are no split points, every row goes to partition 0.
     */
    public static class SplitData extends RichMapPartitionFunction<Row, Tuple2<Integer, Row>> {
        private static final long serialVersionUID = -8110878379231333376L;
        private int taskId;
        private List<Tuple2<Object, Integer>> splitPoints;
        private final int index;

        /** @param i position of the key field in each row */
        public SplitData(int i) {
            this.index = i;
        }

        @Override
        public void close() throws Exception {
            super.close();
            SortUtils.LOG.info("{} close.", getRuntimeContext().getTaskName());
        }

        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            RuntimeContext runtimeContext = getRuntimeContext();
            this.taskId = runtimeContext.getIndexOfThisSubtask();
            // Materialize the broadcast split points once and sort them so they can be binary-searched.
            this.splitPoints = runtimeContext.getBroadcastVariableWithInitializer(
                "splitPoints",
                new BroadcastVariableInitializer<Tuple2<Object, Integer>, List<Tuple2<Object, Integer>>>() {
                    @Override
                    public List<Tuple2<Object, Integer>> initializeBroadcastVariable(
                        Iterable<Tuple2<Object, Integer>> data) {
                        List<Tuple2<Object, Integer>> sorted = new ArrayList<>();
                        for (Tuple2<Object, Integer> point : data) {
                            sorted.add(point);
                        }
                        sorted.sort(new PairComparator());
                        return sorted;
                    }
                });
            SortUtils.LOG.info("{} open.", getRuntimeContext().getTaskName());
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Tuple2<Integer, Row>> out) throws Exception {
            if (splitPoints.isEmpty()) {
                for (Row row : rows) {
                    out.collect(Tuple2.of(0, row));
                }
                return;
            }
            // Created per call (not as fields): PairComparator is not Serializable, and this
            // function instance is shipped to the cluster serialized.
            PairComparator comparator = new PairComparator();
            Tuple2<Object, Integer> probe = new Tuple2<>(null, taskId);
            for (Row row : rows) {
                probe.f0 = row.getField(index);
                int found = Collections.binarySearch(splitPoints, probe, comparator);
                int partitionId = found >= 0 ? found : -found - 1;
                out.collect(Tuple2.of(partitionId, row));
            }
        }
    }

    /**
     * Collects the samples and marker records emitted by {@link SampleSplitPoint} and selects
     * {@code parallelism - 1} evenly spaced split points from the sorted samples. Duplicates are
     * dropped (via {@link HashSet}), so fewer split points may be emitted, in unspecified order.
     */
    public static class SplitPointReducer extends RichGroupReduceFunction<Tuple2<Object, Integer>, Tuple2<Object, Integer>> {
        private static final long serialVersionUID = 2757032418456975057L;

        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            SortUtils.LOG.info("{} open.", getRuntimeContext().getTaskName());
        }

        @Override
        public void close() throws Exception {
            super.close();
            SortUtils.LOG.info("{} close.", getRuntimeContext().getTaskName());
        }

        @Override
        public void reduce(Iterable<Tuple2<Object, Integer>> values, Collector<Tuple2<Object, Integer>> out) throws Exception {
            List<Tuple2<Object, Integer>> samples = new ArrayList<>();
            int parallelism = -1;
            for (Tuple2<Object, Integer> value : values) {
                if (value.f1 < 0) {
                    // Marker record from SampleSplitPoint: f0 carries the parallelism.
                    parallelism = (Integer) value.f0;
                } else {
                    // Copy: the runtime may reuse the incoming tuple instance.
                    samples.add(Tuple2.of(value.f0, value.f1));
                }
            }
            if (samples.isEmpty()) {
                return;
            }

            int total = samples.size();
            samples.sort(new PairComparator());
            int numSplits = parallelism - 1;
            HashSet<Tuple2<Object, Integer>> splits = new HashSet<>();
            for (int s = 0; s < numSplits; s++) {
                int pos = genSampleIndex((long) s, (long) total, (long) numSplits).intValue();
                if (pos >= total) {
                    throw new IllegalStateException("Index error. index: " + pos + ". totalCount: " + total);
                }
                splits.add(samples.get(pos));
            }
            for (Tuple2<Object, Integer> split : splits) {
                out.collect(split);
            }
        }
    }

    /** Counts the rows of each partition id group and emits {@code (partitionId, rowCount)}. */
    private static final class PartitionCounter
        extends RichGroupReduceFunction<Tuple2<Integer, Row>, Tuple2<Integer, Long>> {
        private static final long serialVersionUID = 819186091051991112L;

        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            SortUtils.LOG.info("{} open.", getRuntimeContext().getTaskName());
        }

        @Override
        public void close() throws Exception {
            super.close();
            SortUtils.LOG.info("{} close.", getRuntimeContext().getTaskName());
        }

        @Override
        public void reduce(Iterable<Tuple2<Integer, Row>> values, Collector<Tuple2<Integer, Long>> out) {
            int partitionId = -1;
            long count = 0L;
            for (Tuple2<Integer, Row> value : values) {
                partitionId = value.f0;
                count++;
            }
            out.collect(Tuple2.of(partitionId, count));
        }
    }

    /**
     * Range-partitions {@code dataSet} by the field at position {@code index}.
     *
     * <p>Rows are only tagged with a partition id; nothing here sorts rows within a partition
     * (presumably the caller does that — confirm at call sites).
     *
     * @param dataSet input rows
     * @param index position of the key field; values must be mutually {@link Comparable} or null
     * @return {@code f0}: every input row tagged with its partition id; {@code f1}: one
     *     {@code (partitionId, rowCount)} record per non-empty partition
     */
    public static Tuple2<DataSet<Tuple2<Integer, Row>>, DataSet<Tuple2<Integer, Long>>> pSort(DataSet<Row> dataSet, int index) {
        DataSet<Tuple2<Object, Integer>> splitPoints = dataSet
            .mapPartition(new SampleSplitPoint(index))
            .reduceGroup(new SplitPointReducer());

        DataSet<Tuple2<Integer, Row>> partitioned = dataSet
            .mapPartition(new SplitData(index))
            .withBroadcastSet(splitPoints, "splitPoints");

        DataSet<Tuple2<Integer, Long>> partitionSizes = partitioned
            .groupBy(0)
            .withPartitioner(new AvgPartition())
            .reduceGroup(new PartitionCounter());

        return new Tuple2<>(partitioned, partitionSizes);
    }

    /**
     * Returns the position of the {@code (index + 1)}-th of {@code numSamples} evenly spaced
     * sample points in a sorted sequence of {@code totalCount} elements.
     *
     * <p>The sequence is conceptually cut into {@code numSamples + 1} contiguous buckets. Bucket
     * sizes differ by at most one, and the first {@code totalCount % (numSamples + 1)} buckets
     * hold the extra element. The result is the last position of bucket number {@code index + 1}.
     *
     * @param index zero-based sample number, expected in {@code [0, numSamples)}
     * @param totalCount number of elements in the sorted sequence
     * @param numSamples number of samples to draw; {@code -1} causes division by zero
     * @return zero-based position into the sequence
     * @throws ArithmeticException if {@code numSamples == -1}
     */
    public static Long genSampleIndex(Long index, Long totalCount, Long numSamples) {
        long bucketNo = index + 1;
        long buckets = numSamples + 1;
        long bucketSize = totalCount / buckets;
        long remainder = totalCount % buckets;
        // The first `remainder` buckets are one element larger, so bucket k ends at
        // k * bucketSize + min(remainder, k) (exclusive).
        return bucketSize * bucketNo + Math.min(remainder, bucketNo) - 1;
    }
}
