package com.alibaba.alink.operator.batch.dataproc;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkIllegalStateException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.dataproc.SplitParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction;
import org.apache.flink.shaded.guava18.com.google.common.hash.Hashing;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Randomly splits the input into two outputs.
 *
 * <p>The main output receives round(N * fraction) rows, where N is the total row count. The side
 * output receives all remaining rows. A single-parallelism step first decides how many rows each
 * partition contributes. Each partition then picks that many row positions uniformly at random.
 *
 * <p>When a random seed is set, rows are custom-partitioned and sorted by a seeded murmur3 hash of
 * their string form, and the picked positions are shuffled with that seed. The split then depends
 * on row content, the seed and the parallelism. Without a seed, each partition shuffles with its
 * subtask index as the seed.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)})
@NameCn("数据拆分")
@NameEn("Data Splitting")
public final class SplitBatchOp extends BatchOperator<SplitBatchOp> implements SplitParams<SplitBatchOp> {
    private static final long serialVersionUID = -1436970192619749693L;
    private static final Logger LOG = LoggerFactory.getLogger(SplitBatchOp.class);

    /**
     * Decides how many rows each partition must pick, given {@code (subtaskIndex, rowCount)} pairs
     * for every partition.
     *
     * <p>Must run with parallelism 1. It emits one {@code long[2 * P]} array, where P is the number
     * of partitions. The first half holds the row count of each partition. The second half holds
     * the number of rows each partition must pick, summing to round(total * fraction).
     */
    public static class CountInPartition extends RichMapPartitionFunction<Tuple2<Integer, Long>, long[]> {
        private static final long serialVersionUID = 7797942238612563554L;
        private final double fraction;

        public CountInPartition(double fraction) {
            this.fraction = fraction;
        }

        @Override
        public void mapPartition(Iterable<Tuple2<Integer, Long>> partitionCounts, Collector<long[]> out)
            throws Exception {
            AkPreconditions.checkArgument(getRuntimeContext().getIndexOfThisSubtask() == 0,
                "The index of this task is not zero, but " + getRuntimeContext().getIndexOfThisSubtask());

            List<Tuple2<Integer, Long>> buffer = new ArrayList<>();
            long totalCount = 0L;
            for (Tuple2<Integer, Long> value : partitionCounts) {
                totalCount += value.f1;
                buffer.add(value);
            }

            int numPartitions = buffer.size();
            long[] countInPartition = new long[numPartitions];
            for (Tuple2<Integer, Long> value : buffer) {
                countInPartition[value.f0] = value.f1;
            }

            long totalSelect = Math.round(totalCount * fraction);
            long[] numSelect = new long[numPartitions];
            long floorSum = 0L;
            for (int p = 0; p < numPartitions; p++) {
                numSelect[p] = Math.round(Math.floor(countInPartition[p] * fraction));
                floorSum += numSelect[p];
            }

            if (floorSum < totalSelect) {
                long remain = Math.min(totalSelect - floorSum, totalCount - floorSum);
                if (remain == totalCount - floorSum) {
                    // Every spare row is needed: select everything.
                    System.arraycopy(countInPartition, 0, numSelect, 0, numPartitions);
                } else {
                    distributeRemainder(countInPartition, numSelect, remain);
                }
            }

            long[] result = new long[numPartitions * 2];
            System.arraycopy(countInPartition, 0, result, 0, numPartitions);
            System.arraycopy(numSelect, 0, result, numPartitions, numPartitions);
            out.collect(result);
        }

        /**
         * Hands out {@code remain} extra picks, one per partition per round, in a fixed pseudo-random
         * partition order (seed 0). A partition that is already full passes its pick to the next one.
         * The caller guarantees {@code remain} is smaller than the total spare capacity, so a non-full
         * partition always exists and the inner loop terminates.
         */
        private static void distributeRemainder(long[] capacity, long[] numSelect, long remain) {
            int numPartitions = capacity.length;
            List<Integer> order = new ArrayList<>(numPartitions);
            for (int p = 0; p < numPartitions; p++) {
                order.add(p);
            }
            Collections.shuffle(order, new Random(0L));

            while (remain > 0) {
                for (int k = 0; k < Math.min(remain, numPartitions); k++) {
                    int target = order.get(k);
                    while (numSelect[target] >= capacity[target]) {
                        target = (target + 1) % numPartitions;
                    }
                    numSelect[target]++;
                }
                remain -= numPartitions;
            }
        }
    }

    /**
     * Flags each row of a partition as picked ({@code true}) or not ({@code false}). Picked
     * positions are drawn with a {@link Random} seeded by the subtask index. The per-partition
     * counts come from the {@code "counts"} broadcast produced by {@link CountInPartition}.
     */
    public static class PickInPartition extends RichMapPartitionFunction<Row, Tuple2<Boolean, Row>> {
        private static final long serialVersionUID = 2835501123999397324L;

        private PickInPartition() {
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Tuple2<Boolean, Row>> out) throws Exception {
            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            List<long[]> broadcast = getRuntimeContext().getBroadcastVariable("counts");
            long[] counts = broadcast.get(0);
            // Validate before indexing so a mismatch always reports this error, not an index exception.
            if (counts.length / 2 != numTasks) {
                throw new AkIllegalStateException("parallelism has changed");
            }

            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            int[] picked = sampleSortedIndices(counts[taskId], counts[numTasks + taskId], taskId);

            int rowIndex = 0;
            int pickPos = 0;
            for (Row row : rows) {
                boolean isPicked = pickPos < picked.length && rowIndex == picked[pickPos];
                if (isPicked) {
                    pickPos++;
                }
                out.collect(Tuple2.of(isPicked, row));
                rowIndex++;
            }
        }
    }

    /**
     * Seeded variant of {@link PickInPartition}. Input is {@code (hash, row)} pairs, which
     * {@link #linkFrom} sorts ascending by hash within each partition. Picked positions are drawn
     * with a {@link Random} seeded by the user seed.
     *
     * <p>Rows with equal hashes are reordered with {@link #compareRows}. One row is held back at a
     * time: the largest row of the current equal-hash run seen so far. Smaller incoming rows are
     * emitted ahead of it at the current position. NOTE(review): this is only a partial ordering of
     * equal-hash rows, kept as-is for reproducibility with earlier results.
     */
    public static class PickInPartitionWithSeed
        extends RichMapPartitionFunction<Tuple2<Long, Row>, Tuple2<Boolean, Row>> {
        private static final long serialVersionUID = 2835501123999397324L;
        private final int seed;

        public PickInPartitionWithSeed(int seed) {
            this.seed = seed;
        }

        @Override
        public void mapPartition(Iterable<Tuple2<Long, Row>> rows, Collector<Tuple2<Boolean, Row>> out)
            throws Exception {
            List<long[]> broadcast = getRuntimeContext().getBroadcastVariable("counts");
            long[] counts = broadcast.get(0);
            LOG.info(Arrays.toString(counts));
            int numTasks = counts.length / 2;
            if (numTasks != getRuntimeContext().getNumberOfParallelSubtasks()) {
                throw new AkIllegalStateException("parallelism has changed");
            }

            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            int partitionCount = (int) counts[taskId];
            int[] picked = sampleSortedIndices(partitionCount, (int) counts[numTasks + taskId], seed);

            int seen = 0;
            int pickPos = 0;
            Tuple2<Long, Row> held = null;
            for (Tuple2<Long, Row> cur : rows) {
                if (pickPos >= picked.length) {
                    // Nothing left to pick: pass rows straight through as unpicked.
                    out.collect(Tuple2.of(false, cur.f1));
                } else if (held == null
                    || held.f0.longValue() < cur.f0.longValue()
                    || compareRows(held.f1, cur.f1) <= 0) {
                    // Release the held row at position (seen - 1) and hold the current one instead.
                    if (seen - 1 == picked[pickPos]) {
                        out.collect(Tuple2.of(true, held.f1));
                        pickPos++;
                    } else if (held != null) {
                        out.collect(Tuple2.of(false, held.f1));
                    }
                    held = cur;
                } else {
                    if (held.f0.longValue() > cur.f0.longValue()) {
                        throw new AkUnclassifiedErrorException("Order error!");
                    }
                    // The current row sorts before the held one: emit it at position (seen - 1).
                    if (seen - 1 == picked[pickPos]) {
                        out.collect(Tuple2.of(true, cur.f1));
                        pickPos++;
                    } else {
                        out.collect(Tuple2.of(false, cur.f1));
                    }
                }
                seen++;
            }

            AkPreconditions.checkArgument(seen == partitionCount, "Group value not equal to count value!");
            if (pickPos < picked.length) {
                AkPreconditions.checkArgument(seen - 1 == picked[pickPos] && pickPos + 1 == picked.length,
                    "Inner error, select number not equal to index!");
                out.collect(Tuple2.of(true, held.f1));
            } else if (held != null) {
                out.collect(Tuple2.of(false, held.f1));
            }
        }
    }

    /** Tags each row with a seeded murmur3 128-bit hash (as a long) of its string form. */
    private static class RowHasher extends RichMapFunction<Row, Tuple2<Long, Row>> {
        private static final long serialVersionUID = -287601103797809499L;
        private final int seed;
        // Rebuilt lazily on the task side instead of being serialized with the function.
        private transient HashFunction hashFunction;

        RowHasher(int seed) {
            this.seed = seed;
        }

        @Override
        public Tuple2<Long, Row> map(Row row) {
            if (hashFunction == null) {
                hashFunction = Hashing.murmur3_128(seed);
            }
            return Tuple2.of(hashFunction.hashUnencodedChars(RowUtil.rowToString(row)).asLong(), row);
        }
    }

    /** Routes a row to a partition by the low 32 bits of its hash. */
    private static class HashPartitioner implements Partitioner<Long> {
        private static final long serialVersionUID = 871499283994717282L;

        @Override
        public int partition(Long key, int numPartitions) {
            // Take abs of the remainder, not of the key: Math.abs(Integer.MIN_VALUE) is negative.
            // For every other key this equals Math.abs(key) % numPartitions.
            return Math.abs(key.intValue() % numPartitions);
        }
    }

    /** Forwards the rows whose pick flag equals {@code wanted}. */
    private static class SelectByFlag implements FlatMapFunction<Tuple2<Boolean, Row>, Row> {
        private static final long serialVersionUID = -1015919192379666607L;
        private final boolean wanted;

        SelectByFlag(boolean wanted) {
            this.wanted = wanted;
        }

        @Override
        public void flatMap(Tuple2<Boolean, Row> value, Collector<Row> out) {
            if (value.f0.booleanValue() == wanted) {
                out.collect(value.f1);
            }
        }
    }

    public SplitBatchOp() {
        this(new Params());
    }

    public SplitBatchOp(Params params) {
        super(params);
    }

    /** Creates a split operator that sends {@code fraction} of the rows to the main output. */
    public SplitBatchOp(double fraction) {
        this(new Params().set(FRACTION, fraction));
    }

    /**
     * Wires the split. The main output holds the picked rows; the single side output holds the rest.
     *
     * @throws AkIllegalOperatorParameterException if the fraction is NaN or outside [0, 1]
     */
    @Override
    public SplitBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        final double fraction = getFraction();
        // Written as a positive range test so that NaN is rejected too.
        if (!(fraction >= 0.0 && fraction <= 1.0)) {
            throw new AkIllegalOperatorParameterException("invalid fraction " + fraction);
        }

        DataSet<Row> rows = in.getDataSet();
        Integer seed = getRandomSeed();
        DataSet<Tuple2<Boolean, Row>> flagged;
        if (null != seed) {
            DataSet<Tuple2<Long, Row>> partitioned = rows
                .map(new RowHasher(seed))
                .partitionCustom(new HashPartitioner(), 0);
            flagged = partitioned
                .sortPartition(0, Order.ASCENDING)
                .mapPartition(new PickInPartitionWithSeed(seed))
                .withBroadcastSet(decideCounts(partitioned, fraction), "counts")
                .name("pick_in_each_partition");
        } else {
            flagged = rows
                .mapPartition(new PickInPartition())
                .withBroadcastSet(decideCounts(rows, fraction), "counts")
                .name("pick_in_each_partition");
        }

        DataSet<Row> selected = flagged.flatMap(new SelectByFlag(true));
        DataSet<Row> rest = flagged.flatMap(new SelectByFlag(false));
        setOutput(selected, in.getSchema());
        setSideOutputTables(new Table[] {DataSetConversionUtil.toTable(getMLEnvironmentId(), rest, in.getSchema())});
        return this;
    }

    /**
     * Builds the single-element {@code long[]} dataset with per-partition counts and pick counts
     * for {@code data}. The pick functions read it as the {@code "counts"} broadcast.
     */
    private static <T> DataSet<long[]> decideCounts(DataSet<T> data, double fraction) {
        return DataSetUtils.countElementsPerPartition(data)
            .mapPartition(new CountInPartition(fraction))
            .setParallelism(1)
            .name("decide_count_of_each_partition");
    }

    /**
     * Shuffles the positions {@code 0..total-1} with {@code new Random(seed)}, keeps the first
     * {@code numPick}, and returns them sorted ascending.
     */
    private static int[] sampleSortedIndices(long total, long numPick, long seed) {
        List<Integer> order = new ArrayList<>((int) total);
        for (int k = 0; k < total; k++) {
            order.add(k);
        }
        Collections.shuffle(order, new Random(seed));
        int[] picked = new int[(int) numPick];
        for (int k = 0; k < picked.length; k++) {
            picked[k] = order.get(k);
        }
        Arrays.sort(picked);
        return picked;
    }

    /**
     * Compares two rows field by field. Nulls sort first. {@link Comparable} fields use
     * {@code compareTo}; other fields compare their {@code toString()} forms.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static int compareRows(Row left, Row right) {
        for (int k = 0; k < left.getArity(); k++) {
            Object a = left.getField(k);
            Object b = right.getField(k);
            int cmp;
            if (a == null || b == null) {
                cmp = (a == null) ? (b == null ? 0 : -1) : 1;
            } else if (a instanceof Comparable) {
                cmp = ((Comparable) a).compareTo(b);
            } else {
                cmp = a.toString().compareTo(b.toString());
            }
            if (cmp != 0) {
                return cmp;
            }
        }
        return 0;
    }
}
