package com.alibaba.alink.operator.batch.dataproc;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.dataproc.WeightSampleParams;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Weighted row sampling over a batch input.
 *
 * <p>Each row's weight is read from the numeric column named by {@code weightCol}. Weights must be
 * finite and greater than {@link Criteria#INVALID_GAIN} (NOTE(review): presumably 0.0, matching the
 * "Weight must be positive!" message; confirm). The number of rows drawn is
 * {@code (int) (rowCount * ratio)}.
 *
 * <ul>
 *   <li>With replacement: that many points are drawn uniformly from {@code [0, totalWeight)}. Each
 *       point selects the row whose cumulative-weight interval contains it, so a row may be emitted
 *       several times.</li>
 *   <li>Without replacement: every row gets the key {@code log(u) / weight} with {@code u} uniform in
 *       {@code (1e-30, 1)}, and the rows with the largest keys are kept. This ranks rows the same way
 *       as {@code u^(1/weight)} (the Efraimidis-Spirakis scheme).</li>
 * </ul>
 *
 * <p>Both paths use fixed seeds, so the output is deterministic for a given partitioning and row order.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.DATA)})
@ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})
@NameCn("加权采样")
@NameEn("Weighted Sampling")
public class WeightSampleBatchOp extends BatchOperator<WeightSampleBatchOp>
    implements WeightSampleParams<WeightSampleBatchOp> {
    private static final long serialVersionUID = 8815784097940967758L;

    /** Broadcast name of the per-partition (subtaskIndex, rowCount, minKey) stats (without replacement). */
    private static final String COUNT = DbscanConstant.COUNT;

    /** Broadcast name of the (totalRowCount, cumulativeWeightBounds) summary (with replacement). */
    private static final String BOUNDS = "bounds";

    public WeightSampleBatchOp() {
        this(null);
    }

    public WeightSampleBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the sampling plan on the first input.
     *
     * @param inputs the upstream operators; only the first is used
     * @return this operator, with its output set to the sampled rows (same schema as the input)
     */
    @Override
    public WeightSampleBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        DataSet<Row> data = in.getDataSet();
        final int weightIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), getWeightCol());
        final double ratio = getRatio().doubleValue();

        DataSet<Row> sampled;
        if (getWithReplacement().booleanValue()) {
            DataSet<Tuple2<Integer, double[]>> bounds = data
                .mapPartition(new PartitionWeightSum(weightIdx))
                .reduceGroup(new PartitionBoundsReducer());
            // RandomSelect is serialized once with a Random seeded by 0, so every subtask starts from
            // the same state. Each subtask replays the same global sequence of draws and keeps only the
            // draws that fall inside its own weight range.
            sampled = data
                .mapPartition(new RandomSelect(new Random(0L), ratio, weightIdx))
                .withBroadcastSet(bounds, BOUNDS);
        } else {
            PartitionOperator<Tuple2<Double, Row>> keyed = data
                .mapPartition(new ExponentialKeyGenerator(weightIdx))
                .partitionByRange(0);
            DataSet<Tuple3<Integer, Integer, Double>> keyStats = keyed.mapPartition(new PartitionKeyStat());
            sampled = keyed
                .mapPartition(new TopNSelect(ratio))
                .withBroadcastSet(keyStats, COUNT);
        }
        setOutput(sampled, in.getSchema());
        return this;
    }

    /**
     * Reads and validates the weight of {@code row}.
     *
     * <p>The exception is only constructed on the failure path. Building it eagerly for every row
     * would fill a stack trace per record.
     *
     * @param row the input row
     * @param weightIdx index of the weight column
     * @return the weight as a double
     */
    private static double readWeight(Row row, int weightIdx) {
        Object field = row.getField(weightIdx);
        if (field == null) {
            AkPreconditions.checkArgument(false,
                (ExceptionWithErrorCode) new AkIllegalDataException("Weight must not be null!"));
        }
        double weight = ((Number) field).doubleValue();
        // Double.isFinite is false for NaN and +/-Infinity.
        if (!(weight > Criteria.INVALID_GAIN && Double.isFinite(weight))) {
            AkPreconditions.checkArgument(false,
                (ExceptionWithErrorCode) new AkIllegalDataException("Weight must be positive!"));
        }
        return weight;
    }

    /** Emits one (subtaskIndex, rowCount, weightSum) tuple per partition, validating every weight. */
    private static class PartitionWeightSum
        extends RichMapPartitionFunction<Row, Tuple3<Integer, Integer, Double>> {
        private static final long serialVersionUID = -684553157530047702L;
        private final int weightIdx;

        PartitionWeightSum(int weightIdx) {
            this.weightIdx = weightIdx;
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Tuple3<Integer, Integer, Double>> out) {
            int count = 0;
            double weightSum = 0.0d;
            for (Row row : rows) {
                weightSum += readWeight(row, weightIdx);
                count++;
            }
            out.collect(Tuple3.of(getRuntimeContext().getIndexOfThisSubtask(), count, weightSum));
        }
    }

    /**
     * Folds all per-partition weight sums into (totalRowCount, bounds). Partition {@code p} owns the
     * weight range {@code [bounds[p], bounds[p + 1])}, and {@code bounds[last]} is the total weight.
     */
    private static class PartitionBoundsReducer
        implements GroupReduceFunction<Tuple3<Integer, Integer, Double>, Tuple2<Integer, double[]>> {
        private static final long serialVersionUID = 2912858605429940900L;

        @Override
        public void reduce(Iterable<Tuple3<Integer, Integer, Double>> values,
                           Collector<Tuple2<Integer, double[]>> out) {
            List<Tuple3<Integer, Integer, Double>> stats = new ArrayList<>();
            for (Tuple3<Integer, Integer, Double> stat : values) {
                // Copy: with object reuse enabled Flink may hand out the same instance repeatedly.
                stats.add(stat.copy());
            }
            stats.sort(Comparator.comparingInt(stat -> stat.f0));

            double[] bounds = new double[stats.size() + 1];
            int totalCount = 0;
            for (int p = 0; p < stats.size(); p++) {
                Tuple3<Integer, Integer, Double> stat = stats.get(p);
                bounds[p + 1] = bounds[p] + stat.f2;
                totalCount += stat.f1;
            }
            out.collect(Tuple2.of(totalCount, bounds));
        }
    }

    /**
     * Pairs each row with the key {@code log(u) / weight}, where {@code u} comes from a Random seeded
     * with the subtask index.
     */
    private static class ExponentialKeyGenerator
        extends RichMapPartitionFunction<Row, Tuple2<Double, Row>> {
        private static final long serialVersionUID = -9150449993114999173L;
        private final int weightIdx;

        ExponentialKeyGenerator(int weightIdx) {
            this.weightIdx = weightIdx;
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Tuple2<Double, Row>> out) {
            Random random = new Random(getRuntimeContext().getIndexOfThisSubtask());
            for (Row row : rows) {
                double weight = readWeight(row, weightIdx);
                double u = random.nextDouble();
                // Redraw tiny values so Math.log never sees 0 (nextDouble may return exactly 0.0).
                while (u <= 1.0E-30d) {
                    u = random.nextDouble();
                }
                out.collect(Tuple2.of(Math.log(u) / weight, row));
            }
        }
    }

    /** Emits one (subtaskIndex, rowCount, minKey) tuple per range partition; empty partitions report MAX_VALUE. */
    private static class PartitionKeyStat
        extends RichMapPartitionFunction<Tuple2<Double, Row>, Tuple3<Integer, Integer, Double>> {
        private static final long serialVersionUID = -281138469922874075L;

        @Override
        public void mapPartition(Iterable<Tuple2<Double, Row>> values,
                                 Collector<Tuple3<Integer, Integer, Double>> out) {
            int count = 0;
            double minKey = Double.MAX_VALUE;
            for (Tuple2<Double, Row> value : values) {
                minKey = Math.min(value.f0, minKey);
                count++;
            }
            out.collect(Tuple3.of(getRuntimeContext().getIndexOfThisSubtask(), count, minKey));
        }
    }

    /**
     * With-replacement selection. {@link #open} draws {@code (int) (totalRowCount * ratio)} points
     * uniformly over {@code [0, totalWeight)} and keeps, as offsets, those falling inside this subtask's
     * weight range. {@link #mapPartition} then walks the rows' cumulative weights and emits a row once
     * per offset that falls in its interval.
     */
    public static class RandomSelect extends RichMapPartitionFunction<Row, Row> {
        private static final long serialVersionUID = 5592394863599823024L;
        private final Random random;
        private final int weightIdx;
        private final double ratio;
        private List<Double> cuts = new ArrayList<>();

        /**
         * @param random draw source; every subtask must start from the same state
         * @param ratio number of draws relative to the total row count
         * @param weightIdx index of the weight column
         */
        public RandomSelect(Random random, double ratio, int weightIdx) {
            this.random = random;
            this.ratio = ratio;
            this.weightIdx = weightIdx;
        }

        @Override
        public void open(Configuration parameters) {
            List<Tuple2<Integer, double[]>> broadcast = getRuntimeContext().getBroadcastVariable(BOUNDS);
            Tuple2<Integer, double[]> summary = broadcast.get(0);
            double[] bounds = summary.f1;
            int taskId = getRuntimeContext().getIndexOfThisSubtask();

            double lower = bounds[taskId];
            double upper = bounds[taskId + 1];
            double totalWeight = bounds[bounds.length - 1];
            // The range that ends at the total weight is widened, presumably so that a draw rounding up
            // to exactly totalWeight is still kept.
            double upperExclusive = Double.compare(upper, totalWeight) == 0 ? upper + 0.1d : upper;

            int numDraws = (int) (summary.f0 * ratio);
            cuts = new ArrayList<>();
            for (int i = 0; i < numDraws; i++) {
                double point = random.nextDouble() * totalWeight;
                if (point >= lower && point < upperExclusive) {
                    cuts.add(point - lower);
                }
            }
            Collections.sort(cuts);
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Row> out) {
            if (cuts.isEmpty()) {
                return;
            }
            Iterator<Double> pending = cuts.iterator();
            double nextCut = pending.next();
            double covered = 0.0d;
            for (Row row : rows) {
                double weight = ((Number) row.getField(weightIdx)).doubleValue();
                // The row owns [covered, covered + weight); emit it once per cut inside that interval.
                while (covered + weight > nextCut) {
                    out.collect(row);
                    if (!pending.hasNext()) {
                        return;
                    }
                    nextCut = pending.next();
                }
                covered += weight;
            }
        }
    }

    /**
     * Without-replacement selection over range-partitioned (key, row) pairs. Partitions are ranked by
     * their minimum key, largest first, and the overall target is {@code (int) (totalRows * ratio)}.
     * Given the rows in partitions ranked ahead of this one, the partition emits all, none, or its
     * top-keyed remainder.
     */
    public static class TopNSelect extends RichMapPartitionFunction<Tuple2<Double, Row>, Row> {
        private static final long serialVersionUID = -461361457193125904L;
        private final double ratio;
        private List<Tuple3<Integer, Integer, Double>> partitionStats;

        /** @param ratio fraction of all rows to keep */
        public TopNSelect(double ratio) {
            this.ratio = ratio;
        }

        @Override
        public void open(Configuration parameters) {
            List<Tuple3<Integer, Integer, Double>> stats = getRuntimeContext().getBroadcastVariable(COUNT);
            // Copy before sorting: the broadcast list is shared and must not be mutated.
            partitionStats = new ArrayList<>(stats);
            partitionStats.sort(Comparator.comparingDouble(stat -> -stat.f2));
        }

        @Override
        public void mapPartition(Iterable<Tuple2<Double, Row>> values, Collector<Row> out) {
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            int rowsBefore = 0;
            int rowsThrough = 0;
            int totalRows = 0;
            for (Tuple3<Integer, Integer, Double> stat : partitionStats) {
                if (stat.f0 == taskId) {
                    rowsBefore = totalRows;
                    rowsThrough = totalRows + stat.f1;
                }
                totalRows += stat.f1;
            }

            int target = (int) (totalRows * ratio);
            if (rowsBefore >= target) {
                return;
            }
            if (rowsThrough < target) {
                for (Tuple2<Double, Row> value : values) {
                    out.collect(value.f1);
                }
                return;
            }

            // Keep the `quota` largest keys in a min-heap whose head is the smallest kept key.
            int quota = target - rowsBefore;
            Comparator<Tuple2<Double, Row>> byKey = Comparator.comparingDouble(value -> value.f0);
            PriorityQueue<Tuple2<Double, Row>> heap = new PriorityQueue<>(byKey);
            for (Tuple2<Double, Row> value : values) {
                if (heap.size() < quota) {
                    // Copy: with object reuse enabled the iterator may recycle the same instance.
                    heap.add(value.copy());
                } else if (value.f0 > heap.peek().f0) {
                    heap.poll();
                    heap.add(value.copy());
                }
            }
            for (Tuple2<Double, Row> kept : heap) {
                out.collect(kept.f1);
            }
        }
    }
}
