package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/BlockwiseCross.class */
/**
 * Computes, for every record of a "source" data set, the k best-scoring records of a "target" data set
 * without building the full cross product in a single task.
 *
 * <p>Both inputs are rebalanced and every record is tagged with the index of the subtask that tagged it
 * (its block id). A bulk iteration then runs one superstep per unit of environment parallelism. In superstep
 * {@code s} (1-based) the target block ids are shifted by {@code s - 1} modulo the parallelism, so each source
 * block is co-grouped with a different target block per superstep. Every source record keeps its best
 * candidates in a bounded priority queue that travels with it through the iteration.
 */
public class BlockwiseCross implements Serializable {
    private static final Logger LOG = LoggerFactory.getLogger(BlockwiseCross.class);
    private static final long serialVersionUID = -3156041531337016663L;

    /** Name of the broadcast variable carrying the block shift of the current superstep. */
    private static final String SHIFT_BROADCAST = "shift";

    /**
     * Scores one source record against all target records of a co-grouped block at once.
     *
     * <p>{@link #addTargets} is called once per co-grouped block with the target records
     * {@code (blockId, targetId, value)}; {@link #scoreAll} is then called for each source record of that block
     * and returns the {@code (targetId, score)} pairs to consider. {@link BlockwiseCross#findTopK} copies every
     * pair it keeps, so implementations may reuse the returned list and its tuples between calls. Implementations
     * should copy whatever they need from the target iterable inside {@code addTargets}.
     */
    public interface BulkScoreFunction<T1, T2> extends Function, Serializable {
        void addTargets(Iterable<Tuple3<Integer, Long, T2>> iterable);

        List<Tuple2<Long, Float>> scoreAll(Long l, T1 t1);
    }

    /**
     * Adapts a pairwise {@link ScoreFunction} to {@link BulkScoreFunction} by scoring a source record against
     * every buffered target in turn. The returned list and its tuples are reused across {@link #scoreAll} calls.
     */
    private static class DefaultBulkScoreFunction<T1, T2> implements BulkScoreFunction<T1, T2> {
        private static final long serialVersionUID = -3256840558203304120L;
        /** Targets of the current block as (targetId, value). */
        private transient List<Tuple2<Long, T2>> targets;
        /** Reusable result buffer, index-aligned with {@link #targets}. */
        private transient List<Tuple2<Long, Float>> scoreBuffer;
        private final ScoreFunction<T1, T2> scoreFunction;

        DefaultBulkScoreFunction(ScoreFunction<T1, T2> scoreFunction) {
            this.scoreFunction = scoreFunction;
        }

        @Override
        public void addTargets(Iterable<Tuple3<Integer, Long, T2>> iterable) {
            this.targets = new ArrayList<>();
            this.scoreBuffer = new ArrayList<>();
            for (Tuple3<Integer, Long, T2> target : iterable) {
                this.targets.add(Tuple2.of(target.f1, target.f2));
                this.scoreBuffer.add(Tuple2.of(target.f1, 0.0f));
            }
        }

        @Override
        public List<Tuple2<Long, Float>> scoreAll(Long l, T1 t1) {
            if (targets == null) {
                throw new IllegalStateException("addTargets must be called before scoreAll");
            }
            for (int idx = 0; idx < targets.size(); idx++) {
                Tuple2<Long, T2> target = targets.get(idx);
                scoreBuffer.get(idx).setFields(target.f0, scoreFunction.score(l, t1, target.f0, target.f1));
            }
            return scoreBuffer;
        }
    }

    /**
     * Scores a single (source, target) pair.
     *
     * <p>Arguments are: {@code l} the source id, {@code t1} the source value, {@code l2} the target id and
     * {@code t2} the target value. Scores are ranked according to the {@link Order} passed to
     * {@link BlockwiseCross#findTopK}.
     */
    public interface ScoreFunction<T1, T2> extends Function, Serializable {
        float score(Long l, T1 t1, Long l2, T2 t2);
    }

    /**
     * Orders a top-k heap so that its head is the worst retained candidate: the smallest score for
     * {@link Order#DESCENDING} and the largest score for {@link Order#ASCENDING}.
     */
    private static final class HeapComparator implements Comparator<Tuple2<Long, Float>>, Serializable {
        private static final long serialVersionUID = 2915613386235097512L;
        private final Order order;

        HeapComparator(Order order) {
            this.order = order;
        }

        @Override
        public int compare(Tuple2<Long, Float> left, Tuple2<Long, Float> right) {
            if (order == Order.DESCENDING) {
                return Float.compare(left.f1, right.f1);
            }
            if (order == Order.ASCENDING) {
                return Float.compare(right.f1, left.f1);
            }
            throw new AkUnsupportedOperationException("Not supported order type: " + order);
        }
    }

    /**
     * Finds, for every record of {@code dataSet}, the {@code i} records of {@code dataSet2} with the best score
     * as computed by {@code scoreFunction}.
     *
     * @see #findTopK(DataSet, DataSet, int, Order, BulkScoreFunction)
     */
    public static <T1, T2> DataSet<Tuple3<Long, long[], float[]>> findTopK(DataSet<Tuple2<Long, T1>> dataSet,
            DataSet<Tuple2<Long, T2>> dataSet2, int i, Order order, ScoreFunction<T1, T2> scoreFunction) {
        return findTopK(dataSet, dataSet2, i, order, new DefaultBulkScoreFunction<T1, T2>(scoreFunction));
    }

    /**
     * Finds, for every record of {@code dataSet}, the {@code i} records of {@code dataSet2} with the best score
     * as computed by {@code bulkScoreFunction}.
     *
     * <p>NOTE(review): the rotation assumes the task-id tagging maps run with the environment parallelism, so
     * that block ids lie in {@code [0, parallelism)}; otherwise some block pairs would never be co-grouped.
     *
     * @param dataSet           source records as (id, value)
     * @param dataSet2          target records as (id, value)
     * @param i                 number of best targets (k) to keep per source record; must be positive
     * @param order             {@link Order#DESCENDING} keeps the highest scores, {@link Order#ASCENDING} the lowest
     * @param bulkScoreFunction scores a source record against all targets of a block
     * @return one tuple per source record: (source id, target ids, scores), with at most {@code i} entries per
     *     array, ordered best first
     * @throws IllegalArgumentException       if {@code i <= 0}
     * @throws AkUnsupportedOperationException if {@code order} is neither ASCENDING nor DESCENDING
     */
    public static <T1, T2> DataSet<Tuple3<Long, long[], float[]>> findTopK(DataSet<Tuple2<Long, T1>> dataSet,
            DataSet<Tuple2<Long, T2>> dataSet2, final int i, final Order order,
            final BulkScoreFunction<T1, T2> bulkScoreFunction) {
        // Fail fast on the client: k <= 0 would hit peek() on an empty heap inside a task, and any other
        // order would silently never replace heap entries.
        if (i <= 0) {
            throw new IllegalArgumentException("k must be positive, but got " + i);
        }
        if (order != Order.ASCENDING && order != Order.DESCENDING) {
            throw new AkUnsupportedOperationException("Not supported order type: " + order);
        }
        final int k = i;

        PartitionOperator<Tuple2<Long, T1>> sources = dataSet.rebalance();
        PartitionOperator<Tuple2<Long, T2>> targets = dataSet2.rebalance();
        final int parallelism = sources.getExecutionEnvironment().getParallelism();

        DataSet<Tuple3<Integer, Long, T1>> taggedSources = appendTaskId(sources);
        final DataSet<Tuple3<Integer, Long, T2>> taggedTargets = appendTaskId(targets);

        TypeInformation<PriorityQueue<Tuple2<Long, Float>>> heapType =
            new TypeHint<PriorityQueue<Tuple2<Long, Float>>>() {
            }.getTypeInfo();
        final TupleTypeInfo<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>> stateType =
            new TupleTypeInfo<>(Types.INT, Types.LONG, valueTypeOf(sources), heapType);
        final HeapComparator heapComparator = new HeapComparator(order);

        // Iteration state: (blockId, sourceId, sourceValue, top-k heap).
        IterativeDataSet<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>> iteration = taggedSources
            .map(new RichMapFunction<Tuple3<Integer, Long, T1>,
                Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>>() {
                private static final long serialVersionUID = -4852464977864718718L;

                @Override
                public Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>> map(
                        Tuple3<Integer, Long, T1> value) {
                    return Tuple4.of(value.f0, value.f1, value.f2,
                        new PriorityQueue<Tuple2<Long, Float>>(heapComparator));
                }
            })
            .returns(stateType)
            .withForwardedFields("f0;f1;f2")
            .iterate(parallelism);

        // A single value per superstep: how far to rotate the target block ids.
        DataSet<Integer> superstepShift = iteration
            .mapPartition(new RichMapPartitionFunction<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>,
                Integer>() {
                private static final long serialVersionUID = -8395681840329482945L;

                @Override
                public void mapPartition(Iterable<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>> values,
                                         Collector<Integer> out) {
                    // Only subtask 0 emits, so the broadcast set holds exactly one value.
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        out.collect(getIterationRuntimeContext().getSuperstepNumber() - 1);
                    }
                }
            })
            .returns(Types.INT);

        DataSet<Tuple3<Integer, Long, T2>> rotatedTargets = taggedTargets
            .map(new RichMapFunction<Tuple3<Integer, Long, T2>, Tuple3<Integer, Long, T2>>() {
                private static final long serialVersionUID = 8030878174094151551L;
                private transient int shift;

                @Override
                public void open(Configuration configuration) throws Exception {
                    this.shift = getRuntimeContext().<Integer>getBroadcastVariable(SHIFT_BROADCAST).get(0);
                }

                @Override
                public Tuple3<Integer, Long, T2> map(Tuple3<Integer, Long, T2> value) {
                    return Tuple3.of((value.f0 + shift) % parallelism, value.f1, value.f2);
                }
            })
            .withBroadcastSet(superstepShift, SHIFT_BROADCAST)
            .returns(taggedTargets.getType())
            .withForwardedFields("f1;f2");

        DataSet<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>> crossed = iteration
            .coGroup(rotatedTargets)
            .where(0)
            .equalTo(0)
            .withPartitioner(new Partitioner<Integer>() {
                private static final long serialVersionUID = 1382272229444620156L;

                @Override
                public int partition(Integer key, int numPartitions) {
                    return key % numPartitions;
                }
            })
            .with(new RichCoGroupFunction<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>,
                Tuple3<Integer, Long, T2>, Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>>() {
                private static final long serialVersionUID = -7110970239572056505L;

                @Override
                public void coGroup(Iterable<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>> sourceBlock,
                                    Iterable<Tuple3<Integer, Long, T2>> targetBlock,
                                    Collector<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>> out) {
                    // Defensive: Flink normally passes empty groups rather than null.
                    if (sourceBlock == null) {
                        return;
                    }
                    if (targetBlock == null) {
                        sourceBlock.forEach(out::collect);
                        return;
                    }
                    long startMillis = System.currentTimeMillis();
                    bulkScoreFunction.addTargets(targetBlock);
                    double scoreSeconds = 0.0;
                    double enqueueSeconds = 0.0;
                    int numRecords = 0;
                    int numTargets = 0;
                    for (Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>> record : sourceBlock) {
                        long scoreStart = System.currentTimeMillis();
                        numRecords++;
                        List<Tuple2<Long, Float>> scores = bulkScoreFunction.scoreAll(record.f1, record.f2);
                        numTargets = scores.size();
                        long enqueueStart = System.currentTimeMillis();
                        offerAll(record.f3, scores, k, order);
                        scoreSeconds += 0.001 * (enqueueStart - scoreStart);
                        enqueueSeconds += 0.001 * (System.currentTimeMillis() - enqueueStart);
                        out.collect(record);
                    }
                    LOG.info("Done local cross in {}s, # records {}, # targets {}",
                        (System.currentTimeMillis() - startMillis) * 0.001, numRecords, numTargets);
                    LOG.info("Wall time: score {}s, enqueue {}s", scoreSeconds, enqueueSeconds);
                }
            })
            .returns(stateType)
            .name("block_cross");

        return iteration.closeWith(crossed)
            .map(new MapFunction<Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>>,
                Tuple3<Long, long[], float[]>>() {
                private static final long serialVersionUID = -281332970706889768L;

                @Override
                public Tuple3<Long, long[], float[]> map(
                        Tuple4<Integer, Long, T1, PriorityQueue<Tuple2<Long, Float>>> value) {
                    return toSortedArrays(value.f1, value.f3);
                }
            })
            .returns(new TypeHint<Tuple3<Long, long[], float[]>>() {
            });
    }

    /**
     * Offers every (targetId, score) pair to a top-k heap bounded by {@code k}. Pairs are copied because
     * {@link BulkScoreFunction#scoreAll} may reuse its result tuples.
     */
    private static void offerAll(PriorityQueue<Tuple2<Long, Float>> heap, List<Tuple2<Long, Float>> scores,
                                 int k, Order order) {
        for (Tuple2<Long, Float> scored : scores) {
            float score = scored.f1;
            Long targetId = scored.f0;
            if (heap.size() < k) {
                heap.add(Tuple2.of(targetId, score));
            } else if (isBetter(score, heap.peek().f1, order)) {
                heap.poll();
                heap.add(Tuple2.of(targetId, score));
            }
        }
    }

    /** Returns whether {@code candidate} beats the worst retained score under {@code order}. */
    private static boolean isBetter(float candidate, float worstKept, Order order) {
        return (order == Order.DESCENDING && candidate > worstKept)
            || (order == Order.ASCENDING && candidate < worstKept);
    }

    /**
     * Drains a top-k heap into parallel id/score arrays ordered best first. The heap head is the worst retained
     * candidate, so entries are written from the last slot backwards. The heap is empty afterwards.
     */
    private static Tuple3<Long, long[], float[]> toSortedArrays(Long sourceId,
                                                                PriorityQueue<Tuple2<Long, Float>> heap) {
        int size = heap.size();
        long[] ids = new long[size];
        float[] scores = new float[size];
        for (int pos = size - 1; pos >= 0; pos--) {
            Tuple2<Long, Float> next = heap.poll();
            ids[pos] = next.f0;
            scores[pos] = next.f1;
        }
        return Tuple3.of(sourceId, ids, scores);
    }

    /**
     * Returns the type of field {@code f1} of a {@code Tuple2<Long, T>} data set. Expects the element type to be a
     * {@link TupleTypeInfo}, as Flink produces for {@code Tuple2} element types.
     */
    private static <T> TypeInformation<T> valueTypeOf(DataSet<Tuple2<Long, T>> dataSet) {
        TupleTypeInfo<Tuple2<Long, T>> tupleType = (TupleTypeInfo<Tuple2<Long, T>>) dataSet.getType();
        return tupleType.getTypeAt(1);
    }

    /**
     * Maps every {@code (id, value)} to {@code (subtaskIndex, id, value)}, where the subtask index of the mapping
     * task serves as the record's block id.
     */
    private static <T> DataSet<Tuple3<Integer, Long, T>> appendTaskId(DataSet<Tuple2<Long, T>> dataSet) {
        TupleTypeInfo<Tuple3<Integer, Long, T>> outType =
            new TupleTypeInfo<>(Types.INT, Types.LONG, valueTypeOf(dataSet));
        return dataSet
            .map(new RichMapFunction<Tuple2<Long, T>, Tuple3<Integer, Long, T>>() {
                private static final long serialVersionUID = 148870014283375243L;
                private transient int taskId;

                @Override
                public void open(Configuration configuration) throws Exception {
                    this.taskId = getRuntimeContext().getIndexOfThisSubtask();
                }

                @Override
                public Tuple3<Integer, Long, T> map(Tuple2<Long, T> value) {
                    return Tuple3.of(taskId, value.f0, value.f1);
                }
            })
            .returns(outType)
            .withForwardedFields("f0->f1;f1->f2");
    }
}
