package com.alibaba.alink.operator.common.associationrule;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/associationrule/ParallelPrefixSpan.class */
public class ParallelPrefixSpan {
    private DataSet<int[]> sequences;
    private DataSet<Long> minSupportCnt;
    private DataSet<Tuple2<Integer, Integer>> itemCounts;
    private int maxPatternLength;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/associationrule/ParallelPrefixSpan$Node.class */
    public static class Node {
        Prefix prefix;
        List<Postfix> projectedDB;
        Integer[] nextPrefixItems;
        int numFinished = 0;

        public Node(Prefix prefix, List<Postfix> list) {
            this.prefix = prefix;
            this.projectedDB = list;
        }

        boolean hasFinished() {
            return this.numFinished >= this.nextPrefixItems.length;
        }

        void emitResult(Collector<Tuple2<int[], Integer>> collector) {
            collector.collect(Tuple2.of(this.prefix.items, Integer.valueOf(this.projectedDB.size())));
        }

        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Node[prefix=").append(this.prefix.toString());
            this.projectedDB.forEach(postfix -> {
                sb.append(",postfix=").append(postfix.toString());
            });
            sb.append(",nextPrefixItems=").append(Tuple1.of(this.nextPrefixItems));
            sb.append(",numFinished=").append(this.numFinished).append("/").append(this.nextPrefixItems.length);
            sb.append("]");
            return sb.toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/associationrule/ParallelPrefixSpan$Postfix.class */
    public static class Postfix {
        int sequenceId;
        int start;
        Integer[] partialStarts;

        Postfix(int i) {
            this.sequenceId = i;
            this.start = 0;
            this.partialStarts = new Integer[0];
        }

        Postfix(int i, int i2, Integer[] numArr) {
            this.sequenceId = i;
            this.start = i2;
            this.partialStarts = numArr;
        }

        static List<Postfix> projectAll(List<Postfix> list, List<int[]> list2, int i) {
            ArrayList arrayList = new ArrayList();
            list.forEach(postfix -> {
                Postfix project = postfix.project(list2, i);
                if (project != null) {
                    arrayList.add(project);
                }
            });
            return arrayList;
        }

        static Integer[] getAllNextPrefixItems(List<int[]> list, List<Postfix> list2, long j) {
            HashMap hashMap = new HashMap();
            list2.forEach(postfix -> {
                for (Integer num : postfix.genPrefixItems(list)) {
                    hashMap.merge(num, 1, (num2, num3) -> {
                        return Integer.valueOf(num2.intValue() + num3.intValue());
                    });
                }
            });
            ArrayList arrayList = new ArrayList();
            hashMap.forEach((num, num2) -> {
                if (num2.intValue() >= j) {
                    arrayList.add(num);
                }
            });
            return (Integer[]) arrayList.toArray(new Integer[0]);
        }

        public Postfix project(List<int[]> list, int i) {
            int i2;
            int[] iArr = list.get(this.sequenceId);
            int length = iArr.length - 1;
            boolean z = false;
            int i3 = length;
            ArrayList arrayList = new ArrayList();
            if (i > 0) {
                int i4 = this.start;
                while (iArr[i4] != 0) {
                    i4++;
                }
                for (int i5 = i4; i5 < length; i5++) {
                    if (iArr[i5] == i) {
                        if (!z) {
                            i3 = i5;
                            z = true;
                        }
                        if (iArr[i5 + 1] != 0) {
                            arrayList.add(Integer.valueOf(i5 + 1));
                        }
                    }
                }
            } else if (i < 0) {
                int i6 = -i;
                for (Integer num : this.partialStarts) {
                    int intValue = num.intValue();
                    int i7 = iArr[intValue];
                    while (true) {
                        i2 = i7;
                        if (i2 == i6 || i2 == 0) {
                            break;
                        }
                        intValue++;
                        i7 = iArr[intValue];
                    }
                    if (i2 == i6) {
                        int i8 = intValue + 1;
                        if (!z) {
                            i3 = i8;
                            z = true;
                        }
                        if (iArr[i8] != 0) {
                            arrayList.add(Integer.valueOf(i8));
                        }
                    }
                }
            }
            if (z) {
                return new Postfix(this.sequenceId, i3, (Integer[]) arrayList.toArray(new Integer[0]));
            }
            return null;
        }

        Integer[] genPrefixItems(List<int[]> list) {
            int[] iArr = list.get(this.sequenceId);
            long length = iArr.length - 1;
            HashSet hashSet = new HashSet();
            for (Integer num : this.partialStarts) {
                int intValue = num.intValue();
                int i = iArr[intValue];
                while (true) {
                    int i2 = -i;
                    if (i2 != 0) {
                        if (!hashSet.contains(Integer.valueOf(i2))) {
                            hashSet.add(Integer.valueOf(i2));
                        }
                        intValue++;
                        i = iArr[intValue];
                    }
                }
            }
            int i3 = this.start;
            while (iArr[i3] != 0) {
                i3++;
            }
            for (int i4 = i3; i4 < length; i4++) {
                int i5 = iArr[i4];
                if (i5 != 0 && !hashSet.contains(Integer.valueOf(i5))) {
                    hashSet.add(Integer.valueOf(i5));
                }
            }
            return (Integer[]) hashSet.toArray(new Integer[0]);
        }

        public String toString() {
            return Tuple3.of(Integer.valueOf(this.sequenceId), Integer.valueOf(this.start), this.partialStarts).toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/associationrule/ParallelPrefixSpan$Prefix.class */
    public static class Prefix {
        int[] items;
        int length;

        Prefix(int[] iArr, int i) {
            this.items = iArr;
            this.length = i;
        }

        public Prefix expand(int i) {
            if (i < 0) {
                int[] iArr = new int[this.items.length + 1];
                System.arraycopy(this.items, 0, iArr, 0, this.items.length - 1);
                iArr[this.items.length - 1] = -i;
                iArr[this.items.length] = 0;
                return new Prefix(iArr, this.length + 1);
            }
            int[] iArr2 = new int[this.items.length + 2];
            System.arraycopy(this.items, 0, iArr2, 0, this.items.length);
            iArr2[this.items.length] = i;
            iArr2[this.items.length + 1] = 0;
            return new Prefix(iArr2, this.length + 1);
        }

        public String toString() {
            return Tuple1.of(this.items).toString();
        }
    }

    public ParallelPrefixSpan(DataSet<int[]> dataSet, DataSet<Long> dataSet2, DataSet<Tuple2<Integer, Integer>> dataSet3, int i) {
        this.sequences = dataSet.rebalance();
        this.minSupportCnt = dataSet2;
        this.itemCounts = dataSet3;
        this.maxPatternLength = i;
    }

    public DataSet<Tuple2<int[], Integer>> run() {
        DataSet<Tuple2<Integer, int[]>> partitionSequence = partitionSequence(this.sequences, this.itemCounts);
        final int i = this.maxPatternLength;
        return partitionSequence.partitionCustom(new Partitioner<Integer>() { // from class: com.alibaba.alink.operator.common.associationrule.ParallelPrefixSpan.2
            private static final long serialVersionUID = 5960751544160966750L;

            public int partition(Integer num, int i2) {
                return num.intValue() % i2;
            }
        }, 0).mapPartition(new RichMapPartitionFunction<Tuple2<Integer, int[]>, Tuple2<int[], Integer>>() { // from class: com.alibaba.alink.operator.common.associationrule.ParallelPrefixSpan.1
            private static final long serialVersionUID = -3876003522636592081L;

            public void mapPartition(Iterable<Tuple2<Integer, int[]>> iterable, Collector<Tuple2<int[], Integer>> collector) throws Exception {
                List broadcastVariable = getRuntimeContext().getBroadcastVariable("minSupportCnt");
                List broadcastVariable2 = getRuntimeContext().getBroadcastVariable("itemCounts");
                int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
                int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
                long longValue = ((Long) broadcastVariable.get(0)).longValue();
                ArrayList arrayList = new ArrayList();
                iterable.forEach(tuple2 -> {
                    arrayList.add(tuple2.f1);
                });
                ArrayList arrayList2 = new ArrayList(arrayList.size());
                for (int i2 = 0; i2 < arrayList.size(); i2++) {
                    arrayList2.add(new Postfix(i2));
                }
                int i3 = i;
                broadcastVariable2.forEach(tuple22 -> {
                    int intValue = ((Integer) tuple22.f0).intValue();
                    if (intValue % numberOfParallelSubtasks == indexOfThisSubtask) {
                        ParallelPrefixSpan.generateFreqPattern(arrayList, arrayList2, intValue, longValue, i3, collector);
                    }
                });
            }
        }).withBroadcastSet(this.minSupportCnt, "minSupportCnt").withBroadcastSet(this.itemCounts, "itemCounts").name("generate_freq_pattern");
    }

    private static DataSet<Tuple2<Integer, int[]>> partitionSequence(DataSet<int[]> dataSet, DataSet<Tuple2<Integer, Integer>> dataSet2) {
        return dataSet.flatMap(new RichFlatMapFunction<int[], Tuple2<Integer, int[]>>() { // from class: com.alibaba.alink.operator.common.associationrule.ParallelPrefixSpan.3
            private static final long serialVersionUID = -7729483302097030648L;
            transient Map<Integer, Integer> itemCounts;
            transient boolean[] flags;
            transient int numPartitions;
            static final /* synthetic */ boolean $assertionsDisabled;

            public void open(Configuration configuration) throws Exception {
                this.numPartitions = getRuntimeContext().getNumberOfParallelSubtasks();
                this.itemCounts = new HashMap();
                getRuntimeContext().getBroadcastVariable("itemCounts").forEach(tuple2 -> {
                });
                this.flags = new boolean[this.numPartitions];
            }

            public void flatMap(int[] iArr, Collector<Tuple2<Integer, int[]>> collector) throws Exception {
                if (!$assertionsDisabled && iArr.length != 0 && (iArr[0] != 0 || iArr[iArr.length - 1] != 0)) {
                    throw new AssertionError();
                }
                sort(iArr);
                Arrays.fill(this.flags, false);
                for (int i = 0; i < iArr.length; i++) {
                    if (iArr[i] != 0) {
                        int i2 = iArr[i] % this.numPartitions;
                        if (!this.flags[i2]) {
                            this.flags[i2] = true;
                            int[] copyOfRange = Arrays.copyOfRange(iArr, i - 1, iArr.length);
                            copyOfRange[0] = 0;
                            collector.collect(Tuple2.of(Integer.valueOf(i2), copyOfRange));
                        }
                    }
                }
            }

            private void sort(int[] iArr) {
                int i = 0;
                while (i < iArr.length) {
                    if (iArr[i] == 0) {
                        i++;
                    } else {
                        int i2 = i;
                        while (iArr[i2] != 0) {
                            i2++;
                        }
                        Arrays.sort(iArr, i, i2);
                        i = i2 + 1;
                    }
                }
            }

            public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                flatMap((int[]) obj, (Collector<Tuple2<Integer, int[]>>) collector);
            }

            static {
                $assertionsDisabled = !ParallelPrefixSpan.class.desiredAssertionStatus();
            }
        }).withBroadcastSet(dataSet2, "itemCounts");
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static void generateFreqPattern(List<int[]> list, List<Postfix> list2, int i, long j, int i2, Collector<Tuple2<int[], Integer>> collector) {
        Stack stack = new Stack();
        Node node = new Node(new Prefix(new int[]{0, i, 0}, 1), Postfix.projectAll(list2, list, i));
        node.nextPrefixItems = Postfix.getAllNextPrefixItems(list, node.projectedDB, j);
        node.emitResult(collector);
        stack.push(node);
        while (!stack.empty()) {
            Node node2 = (Node) stack.peek();
            if (node2.hasFinished()) {
                stack.pop();
            } else {
                int intValue = node2.nextPrefixItems[node2.numFinished].intValue();
                Prefix expand = node2.prefix.expand(intValue);
                List<Postfix> projectAll = Postfix.projectAll(node2.projectedDB, list, intValue);
                Node node3 = new Node(expand, projectAll);
                node3.emitResult(collector);
                if (expand.length < i2) {
                    node3.nextPrefixItems = Postfix.getAllNextPrefixItems(list, projectAll, j);
                    stack.push(node3);
                }
                node2.numFinished++;
            }
        }
    }
}
