package com.alibaba.alink.operator.common.associationrule;

import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import javax.annotation.Nullable;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/associationrule/ParallelFpGrowth.class */
/**
 * Parallel FP-growth frequent itemset mining on Flink's DataSet API.
 *
 * <p>The plan built by {@link #run()} has three stages:
 * <ol>
 *   <li>{@code partitionItems} assigns every item to a partition, greedily balancing an estimated
 *   per-item mining workload across partitions;</li>
 *   <li>{@code genCondTransactions} sends, for every transaction, the longest prefix ending at an
 *   item of partition p to partition p (at most one prefix per partition and transaction);</li>
 *   <li>{@code mineFreqItemsets} builds one local {@link FpTree} per partition from the received
 *   prefixes and calls {@code FpTree#extractAll} with the items owned by that partition.</li>
 * </ol>
 */
public class ParallelFpGrowth {
    private static final Logger LOG = LoggerFactory.getLogger(ParallelFpGrowth.class);

    /** Broadcast-set name for the (item, rank, partition) assignment. */
    private static final String PARTITIONER_BC = "partitioner";

    /** Broadcast-set name for the minimum support count. */
    private static final String MIN_SUPPORT_CNT_BC = "minSupportCnt";

    private final DataSet<int[]> transactions;
    private final DataSet<Tuple2<Integer, Integer>> itemCounts;
    private final DataSet<Long> minSupportCnt;
    private final FpTree fpTree;
    private final int maxPatternLength;

    /**
     * Describes a parallel FP-growth mining job.
     *
     * @param fpTree           FP-tree implementation used on every worker; {@code null} selects {@link FpTreeImpl}.
     * @param transactions     the transactions, each an array of item ids.
     * @param itemCounts       (item, count) tuples for the items to mine.
     * @param minSupportCnt    data set whose first element is the minimum support count.
     * @param maxPatternLength maximum pattern length, forwarded to {@code FpTree#extractAll}.
     */
    public ParallelFpGrowth(@Nullable FpTree fpTree, DataSet<int[]> transactions,
                            DataSet<Tuple2<Integer, Integer>> itemCounts, DataSet<Long> minSupportCnt,
                            int maxPatternLength) {
        this.fpTree = fpTree == null ? new FpTreeImpl() : fpTree;
        this.transactions = transactions;
        this.itemCounts = itemCounts;
        this.minSupportCnt = minSupportCnt;
        this.maxPatternLength = maxPatternLength;
    }

    /**
     * Builds the mining plan.
     *
     * @return (itemset, count) tuples as emitted by {@code FpTree#extractAll}; every itemset array
     *         is sorted in ascending order.
     */
    public DataSet<Tuple2<int[], Integer>> run() {
        DataSet<Tuple3<Integer, Integer, Integer>> partitioner = partitionItems(itemCounts);
        DataSet<Tuple2<Integer, int[]>> condTransactions = genCondTransactions(transactions, partitioner);
        // (partition, item) pairs, keyed on partition just like the conditional transactions.
        DataSet<Tuple2<Integer, Integer>> partitionToItem = partitioner.project(2, 0);
        return mineFreqItemsets(condTransactions, partitionToItem, minSupportCnt, maxPatternLength, fpTree);
    }

    /**
     * Assigns every item to a partition.
     *
     * <p>Items are ranked by descending count (ties by ascending item id). The estimated workload of the
     * item at rank r among n items is {@code (r / n) * count}: an item receives the transaction prefixes
     * ending at it, which tend to be longer for higher ranks, assuming transactions list items in this
     * frequency order (NOTE(review): confirm against the caller). Items are then placed
     * largest-workload-first onto the currently least-loaded partition.
     *
     * @param itemCounts (item, count) tuples.
     * @return (item, rank, partition) tuples.
     */
    private static DataSet<Tuple3<Integer, Integer, Integer>> partitionItems(
        DataSet<Tuple2<Integer, Integer>> itemCounts) {
        return itemCounts
            .rebalance()
            .mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Integer>, Tuple3<Integer, Integer, Integer>>() {
                private static final long serialVersionUID = 522893987379062693L;

                @Override
                public void mapPartition(Iterable<Tuple2<Integer, Integer>> values,
                                         Collector<Tuple3<Integer, Integer, Integer>> out) throws Exception {
                    // NOTE(review): this reports the parallelism of THIS operator, which is pinned to 1 by
                    // setParallelism(1) below, so every item currently lands in partition 0. Confirm whether
                    // the job parallelism was intended here.
                    int numPartitions = getRuntimeContext().getNumberOfParallelSubtasks();

                    List<Tuple2<Integer, Integer>> items = new ArrayList<>();
                    for (Tuple2<Integer, Integer> value : values) {
                        // Copy in case the runtime reuses record objects.
                        items.add(Tuple2.of(value.f0, value.f1));
                    }
                    // Descending by count; ties broken by ascending item id for a deterministic order.
                    items.sort((a, b) -> {
                        int cmp = Integer.compare(b.f1, a.f1);
                        return cmp != 0 ? cmp : Integer.compare(a.f0, b.f0);
                    });

                    int numItems = items.size();
                    double[] workload = new double[numItems];
                    for (int rank = 0; rank < numItems; rank++) {
                        // Floating-point division: integer division here would make every weight 0.
                        workload[rank] = ((double) rank / numItems) * items.get(rank).f1;
                    }

                    // Ranks ordered by descending workload; List.sort is stable, so ties keep rank order.
                    List<Integer> ranksByWorkload = new ArrayList<>(numItems);
                    for (int rank = 0; rank < numItems; rank++) {
                        ranksByWorkload.add(rank);
                    }
                    ranksByWorkload.sort((a, b) -> Double.compare(workload[b], workload[a]));

                    // (partition, accumulated workload); the head is the least-loaded partition.
                    Comparator<Tuple2<Integer, Double>> byLoad = Comparator.comparingDouble(t -> t.f1);
                    PriorityQueue<Tuple2<Integer, Double>> loads = new PriorityQueue<>(numPartitions, byLoad);
                    for (int p = 0; p < numPartitions; p++) {
                        loads.add(Tuple2.of(p, 0.0));
                    }

                    for (int rank : ranksByWorkload) {
                        Tuple2<Integer, Double> lightest = loads.poll();
                        lightest.f1 += workload[rank];
                        loads.add(lightest);
                        out.collect(Tuple3.of(items.get(rank).f0, rank, lightest.f0));
                    }
                }
            })
            .setParallelism(1)
            .name("create_partitioner");
    }

    /**
     * Generates conditional transactions.
     *
     * <p>Walking each transaction from its last item backwards, the prefix ending at the current item
     * is emitted to that item's partition, unless a (longer) prefix was already sent there for the same
     * transaction.
     *
     * @param transactions the transactions, each an array of item ids.
     * @param partitioner  (item, rank, partition) tuples, broadcast to every task.
     * @return (partition, prefix) tuples.
     */
    private static DataSet<Tuple2<Integer, int[]>> genCondTransactions(
        DataSet<int[]> transactions, DataSet<Tuple3<Integer, Integer, Integer>> partitioner) {
        return transactions
            .rebalance()
            .flatMap(new RichFlatMapFunction<int[], Tuple2<Integer, int[]>>() {
                private static final long serialVersionUID = -1661328067465216249L;

                /** item id -> partition id. */
                transient Map<Integer, Integer> itemPartition;

                /** Per-transaction scratch: whether a prefix was already sent to each partition. */
                transient boolean[] emitted;

                @Override
                public void open(Configuration parameters) throws Exception {
                    itemPartition = getRuntimeContext().getBroadcastVariableWithInitializer(PARTITIONER_BC,
                        new BroadcastVariableInitializer<Tuple3<Integer, Integer, Integer>, Map<Integer, Integer>>() {
                            @Override
                            public Map<Integer, Integer> initializeBroadcastVariable(
                                Iterable<Tuple3<Integer, Integer, Integer>> data) {
                                Map<Integer, Integer> map = new HashMap<>();
                                for (Tuple3<Integer, Integer, Integer> t : data) {
                                    map.put(t.f0, t.f2);
                                }
                                return map;
                            }
                        });
                    // Size the scratch array by the partition ids actually produced upstream rather than by
                    // this operator's parallelism, which may differ from the partitioner's.
                    int maxPartition = -1;
                    for (int partition : itemPartition.values()) {
                        maxPartition = Math.max(maxPartition, partition);
                    }
                    emitted = new boolean[maxPartition + 1];
                }

                @Override
                public void flatMap(int[] transaction, Collector<Tuple2<Integer, int[]>> out) throws Exception {
                    Arrays.fill(emitted, false);
                    int numEmitted = 0;
                    for (int prefixLen = transaction.length; prefixLen > 0; prefixLen--) {
                        // NOTE(review): throws NullPointerException if the item is missing from the partitioner.
                        int partition = itemPartition.get(transaction[prefixLen - 1]);
                        if (!emitted[partition]) {
                            out.collect(Tuple2.of(partition, Arrays.copyOf(transaction, prefixLen)));
                            emitted[partition] = true;
                            if (++numEmitted == emitted.length) {
                                break; // every partition already has its longest prefix
                            }
                        }
                    }
                }
            })
            .withBroadcastSet(partitioner, PARTITIONER_BC)
            .name("gen_cond_transactions");
    }

    /**
     * Mines each partition locally with an FP-tree.
     *
     * @param condTransactions (partition, prefix) tuples.
     * @param partitionToItem  (partition, item) tuples.
     * @param minSupportCnt    data set whose first element is the minimum support count.
     * @param maxPatternLength maximum pattern length, forwarded to {@code FpTree#extractAll}.
     * @param fpTree           FP-tree implementation; one deserialized copy is reused per task.
     * @return (itemset, count) tuples with every itemset sorted in ascending order.
     */
    private static DataSet<Tuple2<int[], Integer>> mineFreqItemsets(
        DataSet<Tuple2<Integer, int[]>> condTransactions, DataSet<Tuple2<Integer, Integer>> partitionToItem,
        DataSet<Long> minSupportCnt, final int maxPatternLength, final FpTree fpTree) {
        return condTransactions
            .coGroup(partitionToItem)
            .where(0)
            .equalTo(0)
            .withPartitioner(new Partitioner<Integer>() {
                private static final long serialVersionUID = -7286445841896908139L;

                @Override
                public int partition(Integer key, int numPartitions) {
                    // Partition ids are non-negative, so a plain modulo keeps them in range.
                    return key % numPartitions;
                }
            })
            .with(new RichCoGroupFunction<Tuple2<Integer, int[]>, Tuple2<Integer, Integer>, Tuple2<int[], Integer>>() {
                private static final long serialVersionUID = 1429246682769325224L;

                @Override
                public void coGroup(Iterable<Tuple2<Integer, int[]>> partitionTransactions,
                                    Iterable<Tuple2<Integer, Integer>> partitionItems,
                                    Collector<Tuple2<int[], Integer>> out) throws Exception {
                    long minSupport = getRuntimeContext().<Long>getBroadcastVariable(MIN_SUPPORT_CNT_BC).get(0);
                    LOG.info("minSupportCnt = {}", minSupport);
                    // Fail loudly instead of silently wrapping to a negative support threshold.
                    int minSupportInt = Math.toIntExact(minSupport);

                    long start = System.currentTimeMillis();
                    fpTree.createTree();
                    try {
                        for (Tuple2<Integer, int[]> t : partitionTransactions) {
                            fpTree.addTransaction(t.f1);
                        }
                        fpTree.initialize();
                        fpTree.printProfile();

                        List<Integer> items = new ArrayList<>();
                        for (Tuple2<Integer, Integer> t : partitionItems) {
                            items.add(t.f1);
                        }
                        int[] ownedItems = new int[items.size()];
                        for (int k = 0; k < ownedItems.length; k++) {
                            ownedItems[k] = items.get(k);
                        }
                        fpTree.extractAll(ownedItems, minSupportInt, maxPatternLength, out);
                    } finally {
                        fpTree.destroyTree();
                    }
                    LOG.info("Done local FpGrowth in {}s.", (System.currentTimeMillis() - start) / 1000);
                }
            })
            .withBroadcastSet(minSupportCnt, MIN_SUPPORT_CNT_BC)
            .name("fpgrowth")
            .map(new MapFunction<Tuple2<int[], Integer>, Tuple2<int[], Integer>>() {
                private static final long serialVersionUID = -3684281661645139365L;

                @Override
                public Tuple2<int[], Integer> map(Tuple2<int[], Integer> value) throws Exception {
                    int[] itemset = value.f0;
                    Arrays.sort(itemset);
                    return Tuple2.of(itemset, value.f1);
                }
            });
    }
}
