package com.alibaba.alink.operator.batch.associationrule;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.associationrule.FpTreeImpl;
import com.alibaba.alink.params.associationrule.GroupedFpGrowthParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.StringUtils;

@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH)})
@OutputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.ASSOCIATION_PATTERNS), @PortSpec(value = PortType.MODEL, desc = PortDesc.ASSOCIATION_RULES)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "itemsCol", allowedTypeCollections = {TypeCollections.STRING_TYPE}), @ParamSelectColumnSpec(name = "groupCol")})
@NameCn("分组FPGrowth训练")
@NameEn("Grouped FpGrowth Training")
/* loaded from: input_file:com/alibaba/alink/operator/batch/associationrule/GroupedFpGrowthBatchOp.class */
public final class GroupedFpGrowthBatchOp extends BatchOperator<GroupedFpGrowthBatchOp> implements GroupedFpGrowthParams<GroupedFpGrowthBatchOp> {
    private static final long serialVersionUID = -3434563610385164063L;

    /**
     * Creates the operator without passing any parameters to the superclass.
     */
    public GroupedFpGrowthBatchOp() {
        super();
    }

    /**
     * Creates the operator from an existing parameter set.
     *
     * @param operatorParams parameters handed unchanged to the superclass constructor
     */
    public GroupedFpGrowthBatchOp(Params operatorParams) {
        super(operatorParams);
    }

    /**
     * Mines frequent item sets and association rules independently for every value of the
     * group column.
     *
     * <p>The input returned by {@code checkAndGetFirst} is projected to
     * {@code (groupCol, itemsCol)}. Each items value is treated as a comma-separated list of
     * tokens, and rows are keyed by {@code String.valueOf(groupValue)}. Three steps follow:
     * <ol>
     *   <li>the number of rows per group key is counted and later broadcast as
     *       {@code "groupCount"};</li>
     *   <li>{@code "gen_patterns"} builds one FP-tree per group and emits rows
     *       {@code (group, itemset, supportCount, itemCount)};</li>
     *   <li>{@code "gen_rules"} derives, for every pattern with more than one item, one rule per
     *       item used as consequent and keeps those meeting the minimum lift and confidence.</li>
     * </ol>
     *
     * <p>The patterns table becomes the main output and the rules table the only side output.
     * Both are prefixed with the group column (name and type taken from the projected input).
     *
     * @param batchOperatorArr upstream operators, resolved through {@code checkAndGetFirst}
     * @return this operator, with its output tables set
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public GroupedFpGrowthBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);
        String itemsColName = getItemsCol();
        String groupColName = getGroupCol();
        final int minSupportCount = getMinSupportCount().intValue();
        final double minSupportPercent = getMinSupportPercent().doubleValue();
        final int maxPatternLength = getMaxPatternLength().intValue();
        final double minLift = getMinLift().doubleValue();
        final double minConfidence = getMinConfidence().doubleValue();

        BatchOperator<?> selected = input.select(new String[]{groupColName, itemsColName});

        // Key every row by the string form of its group value so any column type can be grouped on.
        DataSet<Tuple2<String, Row>> keyedRows = selected.getDataSet()
            .map(new MapFunction<Row, Tuple2<String, Row>>() {
                private static final long serialVersionUID = -2408050110594809903L;

                @Override
                public Tuple2<String, Row> map(Row row) throws Exception {
                    return Tuple2.of(String.valueOf(row.getField(0)), row);
                }
            });

        // Number of input rows (transactions) per group key; needed for support and lift.
        DataSet<Tuple2<String, Long>> groupSizes = keyedRows
            .map(new MapFunction<Tuple2<String, Row>, Tuple2<String, Long>>() {
                private static final long serialVersionUID = -9076532747012786151L;

                @Override
                public Tuple2<String, Long> map(Tuple2<String, Row> keyed) throws Exception {
                    return Tuple2.of(keyed.f0, 1L);
                }
            })
            .groupBy(0)
            .reduce(new ReduceFunction<Tuple2<String, Long>>() {
                private static final long serialVersionUID = 6511332413000462792L;

                @Override
                public Tuple2<String, Long> reduce(Tuple2<String, Long> left, Tuple2<String, Long> right)
                    throws Exception {
                    left.f1 = left.f1 + right.f1;
                    return left;
                }
            });

        DataSet<Tuple2<String, Row>> patterns = keyedRows
            .groupBy(0)
            .reduceGroup(new RichGroupReduceFunction<Tuple2<String, Row>, Tuple2<String, Row>>() {
                private static final long serialVersionUID = -2758244844753267506L;

                @Override
                public void reduce(Iterable<Tuple2<String, Row>> iterable,
                                   final Collector<Tuple2<String, Row>> collector) throws Exception {
                    Object lastGroupValue = null;
                    List<String> transactions = new ArrayList<>();
                    for (Tuple2<String, Row> keyed : iterable) {
                        lastGroupValue = keyed.f1.getField(0);
                        transactions.add((String) keyed.f1.getField(1));
                    }
                    final Object groupValue = lastGroupValue;

                    final int groupMinSupport = GroupedFpGrowthBatchOp.decideMinSupportCount(
                        minSupportCount, minSupportPercent, transactions.size());
                    Map<String, Integer> itemCounts = GroupedFpGrowthBatchOp.getItemCounts(transactions);
                    Tuple2<Map<String, Integer>, List<String>> ordered =
                        GroupedFpGrowthBatchOp.orderItems(itemCounts);
                    Map<String, Integer> indexByItem = ordered.f0;
                    final List<String> itemByIndex = ordered.f1;
                    int[] qualifiedItemIndices = GroupedFpGrowthBatchOp.getQualifiedItemIndices(
                        itemCounts, indexByItem, groupMinSupport);

                    FpTreeImpl fpTree = new FpTreeImpl();
                    fpTree.createTree();
                    for (String transaction : transactions) {
                        if (StringUtils.isNullOrWhitespaceOnly(transaction)) {
                            continue;
                        }
                        String[] tokens = transaction.split(",");
                        // A set drops repeated tokens; only sufficiently frequent items enter the tree.
                        Set<Integer> frequentIndices = new HashSet<>(tokens.length);
                        for (String token : tokens) {
                            if (itemCounts.get(token) >= groupMinSupport) {
                                frequentIndices.add(indexByItem.get(token));
                            }
                        }
                        int[] sortedIndices = GroupedFpGrowthBatchOp.toArray(frequentIndices);
                        Arrays.sort(sortedIndices);
                        fpTree.addTransaction(sortedIndices);
                    }
                    fpTree.initialize();
                    fpTree.printProfile();
                    fpTree.extractAll(qualifiedItemIndices, groupMinSupport, maxPatternLength,
                        new Collector<Tuple2<int[], Integer>>() {
                            @Override
                            public void collect(Tuple2<int[], Integer> pattern) {
                                String itemset = GroupedFpGrowthBatchOp.indicesToTokens(pattern.f0, itemByIndex);
                                long supportCount = pattern.f1.intValue();
                                long itemCount = pattern.f0.length;
                                Row row = new Row(FpGrowthBatchOp.ITEMSETS_COL_NAMES.length + 1);
                                row.setField(0, groupValue);
                                row.setField(1, itemset);
                                row.setField(2, Long.valueOf(supportCount));
                                row.setField(3, Long.valueOf(itemCount));
                                collector.collect(Tuple2.of(String.valueOf(groupValue), row));
                            }

                            @Override
                            public void close() {
                            }
                        });
                    fpTree.destroyTree();
                }
            })
            .name("gen_patterns");

        DataSet<Row> rules = patterns
            .groupBy(0)
            .reduceGroup(new RichGroupReduceFunction<Tuple2<String, Row>, Row>() {
                private static final long serialVersionUID = 3848228188758749261L;
                // Row count per group key, rebuilt from the broadcast set whenever the function is opened.
                transient Map<String, Long> groupSizeByKey;

                @Override
                public void open(Configuration configuration) throws Exception {
                    List<Tuple2<String, Long>> broadcast = getRuntimeContext().getBroadcastVariable("groupCount");
                    this.groupSizeByKey = new HashMap<>(broadcast.size());
                    for (Tuple2<String, Long> groupSize : broadcast) {
                        // putIfAbsent keeps the first entry per key, like a first-match scan of the list.
                        this.groupSizeByKey.putIfAbsent(groupSize.f0, groupSize.f1);
                    }
                }

                @Override
                public void reduce(Iterable<Tuple2<String, Row>> iterable, Collector<Row> collector)
                    throws Exception {
                    Map<String, Long> supportByPattern = new HashMap<>();
                    String groupKey = null;
                    Object groupValue = null;
                    for (Tuple2<String, Row> pattern : iterable) {
                        groupKey = pattern.f0;
                        // Use the typed group value carried by the pattern row rather than the string
                        // key, so column 0 matches the group column type declared for the rules table.
                        groupValue = pattern.f1.getField(0);
                        supportByPattern.put((String) pattern.f1.getField(1), (Long) pattern.f1.getField(2));
                    }

                    Long transactionCount = this.groupSizeByKey.get(groupKey);
                    Preconditions.checkArgument(transactionCount != null,
                        "No row count was broadcast for group %s", groupKey);

                    for (Map.Entry<String, Long> entry : supportByPattern.entrySet()) {
                        String[] items = entry.getKey().split(",");
                        if (items.length <= 1) {
                            // A single item cannot be split into antecedent and consequent.
                            continue;
                        }
                        Long supportXY = entry.getValue();
                        Preconditions.checkArgument(supportXY != null,
                            "Missing support count for pattern %s", entry.getKey());

                        for (int consequentPos = 0; consequentPos < items.length; consequentPos++) {
                            // Antecedent = all items except the consequent, in their original order.
                            StringBuilder antecedentBuilder = new StringBuilder();
                            for (int pos = 0; pos < items.length; pos++) {
                                if (pos == consequentPos) {
                                    continue;
                                }
                                if (antecedentBuilder.length() > 0) {
                                    antecedentBuilder.append(",");
                                }
                                antecedentBuilder.append(items[pos]);
                            }
                            String antecedent = antecedentBuilder.toString();
                            String consequent = items[consequentPos];

                            Long supportX = supportByPattern.get(antecedent);
                            Long supportY = supportByPattern.get(consequent);
                            Preconditions.checkArgument(supportX != null,
                                "Missing support count for antecedent %s", antecedent);
                            Preconditions.checkArgument(supportY != null,
                                "Missing support count for consequent %s", consequent);

                            // confidence = |XY| / |X|; lift = |XY| * N / (|X| * |Y|); support = |XY| / N
                            double confidence = supportXY.doubleValue() / supportX.doubleValue();
                            double lift = (supportXY.doubleValue() * transactionCount.doubleValue())
                                / (supportX.doubleValue() * supportY.doubleValue());
                            double support = supportXY.doubleValue() / transactionCount.doubleValue();

                            if (lift >= minLift && confidence >= minConfidence) {
                                // Group column followed by the six rule columns.
                                Row row = new Row(7);
                                row.setField(0, groupValue);
                                row.setField(1, antecedent + PrefixSpanBatchOp.RULE_SEPARATOR + consequent);
                                row.setField(2, Long.valueOf(items.length));
                                row.setField(3, Double.valueOf(lift));
                                row.setField(4, Double.valueOf(support));
                                row.setField(5, Double.valueOf(confidence));
                                row.setField(6, supportXY);
                                collector.collect(row);
                            }
                        }
                    }
                }
            })
            .withBroadcastSet(groupSizes, "groupCount")
            .name("gen_rules");

        DataSet<Row> patternRows = patterns.map(new MapFunction<Tuple2<String, Row>, Row>() {
            private static final long serialVersionUID = -4247869441801301592L;

            @Override
            public Row map(Tuple2<String, Row> keyed) throws Exception {
                return keyed.f1;
            }
        });

        Table patternTable = DataSetConversionUtil.toTable(
            getMLEnvironmentId(),
            patternRows,
            (String[]) ArrayUtils.addAll(new String[]{selected.getColNames()[0]}, FpGrowthBatchOp.ITEMSETS_COL_NAMES),
            (TypeInformation<?>[]) ArrayUtils.addAll(new TypeInformation[]{selected.getColTypes()[0]}, FpGrowthBatchOp.ITEMSETS_COL_TYPES));
        Table ruleTable = DataSetConversionUtil.toTable(
            getMLEnvironmentId(),
            rules,
            (String[]) ArrayUtils.addAll(new String[]{selected.getColNames()[0]}, FpGrowthBatchOp.RULES_COL_NAMES),
            (TypeInformation<?>[]) ArrayUtils.addAll(new TypeInformation[]{selected.getColTypes()[0]}, FpGrowthBatchOp.RULES_COL_TYPES));
        setOutputTable(patternTable);
        setSideOutputTables(new Table[]{ruleTable});
        return this;
    }

    /**
     * Resolves the absolute minimum support count to use for one group.
     *
     * @param i  configured absolute minimum support count; a negative value means "not set"
     * @param d  minimum support as a fraction of the transaction count, used only when
     *           {@code i} is negative
     * @param i2 number of transactions in the group
     * @return {@code i} when it is non-negative, otherwise {@code floor(i2 * d)} as an int
     */
    public static int decideMinSupportCount(int i, double d, int i2) {
        if (i < 0) {
            double fractionalCount = i2 * d;
            return (int) Math.floor(fractionalCount);
        }
        return i;
    }

    /**
     * Counts, for every item, the number of transactions that contain it.
     *
     * <p>Each transaction is a comma-separated list of items. Null or whitespace-only
     * transactions are skipped, and an item repeated inside one transaction is counted once
     * for that transaction.
     *
     * @param list transactions of one group
     * @return a new mutable map from item to the number of transactions containing it
     */
    public static Map<String, Integer> getItemCounts(List<String> list) {
        Map<String, Integer> transactionsPerItem = new HashMap<>();
        for (String transaction : list) {
            if (StringUtils.isNullOrWhitespaceOnly(transaction)) {
                continue;
            }
            // De-duplicate first so that a repeated item contributes a single count per transaction.
            Set<String> distinctItems = new HashSet<>(Arrays.asList(transaction.split(",")));
            for (String item : distinctItems) {
                transactionsPerItem.merge(item, 1, Integer::sum);
            }
        }
        return transactionsPerItem;
    }

    /**
     * Ranks items by descending count and assigns each item its rank as a dense index.
     *
     * <p>The sort is stable, so items with equal counts keep the relative order in which the
     * given map iterates over them.
     *
     * @param map item to count
     * @return {@code f0}: item to index (0 = highest count); {@code f1}: items listed in index
     *         order, so {@code f1.get(f0.get(item))} is {@code item}
     */
    public static Tuple2<Map<String, Integer>, List<String>> orderItems(final Map<String, Integer> map) {
        List<String> rankedItems = new ArrayList<>(map.keySet());
        rankedItems.sort((left, right) -> Integer.compare(map.get(right), map.get(left)));

        Map<String, Integer> rankByItem = new HashMap<>(map.size());
        int rank = 0;
        for (String item : rankedItems) {
            rankByItem.put(item, rank);
            rank++;
        }
        return Tuple2.of(rankByItem, rankedItems);
    }

    /**
     * Copies a set of boxed integers into a primitive array, in the set's iteration order.
     *
     * @param set values to copy; must not contain {@code null}
     * @return a new array with one entry per element of {@code set}
     */
    public static int[] toArray(Set<Integer> set) {
        return set.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Returns the indices of all items whose count reaches the minimum support.
     *
     * <p>The result follows the iteration order of {@code map}.
     *
     * @param map  item to count
     * @param map2 item to index; must contain every qualifying item
     * @param i    minimum count an item needs in order to qualify
     * @return indices (looked up in {@code map2}) of the qualifying items
     */
    public static int[] getQualifiedItemIndices(Map<String, Integer> map, Map<String, Integer> map2, int i) {
        return map.entrySet().stream()
            .filter(itemCount -> itemCount.getValue().intValue() >= i)
            .mapToInt(itemCount -> map2.get(itemCount.getKey()).intValue())
            .toArray();
    }

    /**
     * Translates item indices back to their tokens and joins them with commas.
     *
     * @param iArr item indices, each a valid position in {@code list}
     * @param list tokens addressed by index
     * @return the tokens for {@code iArr} in the given order, separated by {@code ","};
     *         the empty string when {@code iArr} is empty
     */
    public static String indicesToTokens(int[] iArr, List<String> list) {
        StringBuilder joined = new StringBuilder();
        String separator = "";
        for (int itemIndex : iArr) {
            joined.append(separator).append(list.get(itemIndex));
            separator = ",";
        }
        return joined.toString();
    }

    /*
     * Bridge overload as reproduced by the decompiler (see the bridge/synthetic markers below):
     * it only forwards its argument, cast to BatchOperator<?>[], to linkFrom.
     * NOTE(review): BatchOperator<?>... and BatchOperator[] erase to the same parameter type, so
     * in hand-written source this declaration would presumably clash with the varargs linkFrom
     * above and should likely be dropped when the decompiled file is cleaned up -- confirm.
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ GroupedFpGrowthBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
