package com.alibaba.alink.operator.batch.associationrule;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.associationrule.AssociationRule;
import com.alibaba.alink.operator.common.associationrule.FpTree;
import com.alibaba.alink.operator.common.associationrule.ParallelFpGrowth;
import com.alibaba.alink.params.associationrule.FpGrowthParams;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH)})
@OutputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.ASSOCIATION_PATTERNS), @PortSpec(value = PortType.MODEL, desc = PortDesc.ASSOCIATION_RULES)})
@ParamSelectColumnSpec(name = "itemsCol", allowedTypeCollections = {TypeCollections.STRING_TYPE})
@NameCn("FpGrowth")
@NameEn("FpGrowth")
/* loaded from: input_file:com/alibaba/alink/operator/batch/associationrule/FpGrowthBatchOp.class */
public final class FpGrowthBatchOp extends BatchOperator<FpGrowthBatchOp> implements FpGrowthParams<FpGrowthBatchOp> {
    /** Separator between items of an itemset string; the same "," literal is used when splitting input and when joining output items. */
    public static final String ITEM_SEPARATOR = ",";
    private static final long serialVersionUID = -1737068631601351390L;
    /** Optional tree instance handed to {@code ParallelFpGrowth} in {@code linkFrom}; stays {@code null} unless {@link #setFpTree} is called. */
    private FpTree fpTree;
    private static final Logger LOG = LoggerFactory.getLogger(FpGrowthBatchOp.class);
    /** Column names of the main output table (frequent itemsets). */
    static final String[] ITEMSETS_COL_NAMES = {"itemset", "supportcount", "itemcount"};
    /** Column names of the side output table (association rules). */
    static final String[] RULES_COL_NAMES = {"rule", "itemcount", "lift", "support_percent", "confidence_percent", "transaction_count"};
    /** Column types matching {@link #ITEMSETS_COL_NAMES} position by position. */
    static final TypeInformation[] ITEMSETS_COL_TYPES = {Types.STRING, Types.LONG, Types.LONG};
    /** Column types matching {@link #RULES_COL_NAMES} position by position. */
    static final TypeInformation[] RULES_COL_TYPES = {Types.STRING, Types.LONG, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE, Types.LONG};

    /** Creates the operator with an empty {@link Params} set by delegating to {@link #FpGrowthBatchOp(Params)}. */
    public FpGrowthBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameter set forwarded unchanged to the {@code BatchOperator} constructor
     */
    public FpGrowthBatchOp(Params params) {
        super(params);
    }

    /**
     * Wires the FP-Growth job on top of the first input operator.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>split the items column of every row on {@link #ITEM_SEPARATOR} into a set of items
     *       (null or blank cells become an empty set);</li>
     *   <li>count the transactions and derive the minimum support count;</li>
     *   <li>count the occurrences of every item and keep the items whose count reaches the minimum;</li>
     *   <li>assign each kept item an integer index, ordered by descending count and, for equal counts,
     *       by ascending item string;</li>
     *   <li>encode every transaction as a sorted array of item indexes (items without index are dropped);</li>
     *   <li>run {@code ParallelFpGrowth} and {@code AssociationRule.extractRules}, then translate the
     *       indexes back to item strings.</li>
     * </ol>
     * The frequent itemsets become the main output table, the association rules become side output 0.
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator
     */
    @Override
    public FpGrowthBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        final String itemsColName = getItemsCol();
        final int supportCountParam = getMinSupportCount().intValue();
        final double supportPercentParam = getMinSupportPercent().doubleValue();
        final double minConfidence = getMinConfidence().doubleValue();
        final int maxConsequentLength = getMaxConsequentLength().intValue();
        final double minLift = getMinLift().doubleValue();
        final int maxPatternLength = getMaxPatternLength().intValue();
        final int itemsColIdx = TableUtil.findColIndexWithAssertAndHint(in.getSchema(), itemsColName);

        // One set of distinct item strings per input row.
        DataSet<Set<String>> itemsets = in.getDataSet().map(new MapFunction<Row, Set<String>>() {
            private static final long serialVersionUID = -2980403044705096669L;

            @Override
            public Set<String> map(Row row) throws Exception {
                Set<String> items = new HashSet<>();
                String cell = (String) row.getField(itemsColIdx);
                if (!StringUtils.isNullOrWhitespaceOnly(cell)) {
                    items.addAll(Arrays.asList(cell.split(ITEM_SEPARATOR)));
                }
                return items;
            }
        });

        DataSet<Long> transactionsCnt = count(itemsets);
        DataSet<Long> minSupportCnt = getMinSupportCnt(transactionsCnt, supportCountParam, supportPercentParam);

        // (item, occurrences): pre-aggregated per partition, then summed per item.
        DataSet<Tuple2<String, Integer>> itemCounts = itemsets
            .mapPartition(new MapPartitionFunction<Set<String>, Tuple2<String, Integer>>() {
                private static final long serialVersionUID = -1276463140867640813L;

                @Override
                public void mapPartition(Iterable<Set<String>> transactions, Collector<Tuple2<String, Integer>> out) throws Exception {
                    Map<String, Integer> localCounts = new HashMap<>();
                    for (Set<String> transaction : transactions) {
                        for (String item : transaction) {
                            localCounts.merge(item, 1, Integer::sum);
                        }
                    }
                    localCounts.forEach((item, cnt) -> out.collect(Tuple2.of(item, cnt)));
                }
            })
            .groupBy(0)
            .reduce(new ReduceFunction<Tuple2<String, Integer>>() {
                private static final long serialVersionUID = 7570427842999572025L;

                @Override
                public Tuple2<String, Integer> reduce(Tuple2<String, Integer> left, Tuple2<String, Integer> right) throws Exception {
                    left.f1 = left.f1 + right.f1;
                    return left;
                }
            });

        // Items whose occurrence count reaches the broadcast minimum support count.
        DataSet<Tuple2<String, Integer>> qualifiedItems = itemCounts
            .filter(new RichFilterFunction<Tuple2<String, Integer>>() {
                private static final long serialVersionUID = 7612046769658762503L;
                transient Long minSupportCount;

                @Override
                public void open(Configuration configuration) throws Exception {
                    List<Long> broadcast = getRuntimeContext().getBroadcastVariable("minSupportCnt");
                    this.minSupportCount = broadcast.get(0);
                    FpGrowthBatchOp.LOG.info("minSupportCnt {}", this.minSupportCount);
                }

                @Override
                public boolean filter(Tuple2<String, Integer> itemCount) throws Exception {
                    return (long) itemCount.f1 >= this.minSupportCount;
                }
            })
            .withBroadcastSet(minSupportCnt, "minSupportCnt")
            .name("getQualifiedItems");

        // (item, index): a single dummy element drives one flatMap call that ranks all qualified items.
        DataSet<Tuple2<String, Integer>> itemIndex = itemsets.getExecutionEnvironment()
            .fromElements(0)
            .flatMap(new RichFlatMapFunction<Integer, Tuple2<String, Integer>>() {
                private static final long serialVersionUID = -8493922755933898289L;

                @Override
                public void flatMap(Integer ignored, Collector<Tuple2<String, Integer>> out) throws Exception {
                    final List<Tuple2<String, Integer>> items = getRuntimeContext().getBroadcastVariable("qualifiedItems");
                    Integer[] order = new Integer[items.size()];
                    for (int i = 0; i < order.length; i++) {
                        order[i] = i;
                    }
                    // Descending by count; equal counts fall back to ascending item string.
                    Arrays.sort(order, new Comparator<Integer>() {
                        @Override
                        public int compare(Integer a, Integer b) {
                            Integer countA = items.get(a).f1;
                            Integer countB = items.get(b).f1;
                            if (countA.equals(countB)) {
                                return items.get(a).f0.compareTo(items.get(b).f0);
                            }
                            return Integer.compare(countB, countA);
                        }
                    });
                    for (int rank = 0; rank < order.length; rank++) {
                        out.collect(Tuple2.of(items.get(order[rank]).f0, rank));
                    }
                }
            })
            .withBroadcastSet(qualifiedItems, "qualifiedItems");

        // Transactions encoded as sorted arrays of item indexes.
        DataSet<int[]> transactions = itemsets
            .map(new RichMapFunction<Set<String>, int[]>() {
                private static final long serialVersionUID = -8385780561646693463L;
                transient Map<String, Integer> tokenToId;

                @Override
                public void open(Configuration configuration) throws Exception {
                    this.tokenToId = new HashMap<>();
                    List<Tuple2<String, Integer>> indexed = getRuntimeContext().getBroadcastVariable("itemIndex");
                    // Without this the lookup table stays empty and every transaction encodes to int[0].
                    for (Tuple2<String, Integer> entry : indexed) {
                        this.tokenToId.put(entry.f0, entry.f1);
                    }
                }

                @Override
                public int[] map(Set<String> items) throws Exception {
                    int[] ids = new int[items.size()];
                    int used = 0;
                    for (String item : items) {
                        Integer id = this.tokenToId.get(item);
                        if (id != null) {
                            ids[used++] = id;
                        }
                    }
                    int[] encoded = Arrays.copyOf(ids, used);
                    Arrays.sort(encoded);
                    return encoded;
                }
            })
            .withBroadcastSet(itemIndex, "itemIndex");

        // (item index, occurrences) for every indexed item.
        DataSet<Tuple2<Integer, Integer>> itemIndexCounts = itemCounts
            .join(itemIndex).where(0).equalTo(0)
            .projectSecond(1).projectFirst(1);

        DataSet<Tuple2<int[], Integer>> patterns =
            new ParallelFpGrowth(this.fpTree, transactions, itemIndexCounts, minSupportCnt, maxPatternLength).run();
        DataSet<Tuple4<int[], int[], Integer, double[]>> rules =
            AssociationRule.extractRules(patterns, transactionsCnt, itemIndexCounts, minConfidence, minLift, maxConsequentLength);

        DataSet<Row> patternRows = patternsIndexToString(patterns, itemIndex);
        DataSet<Row> ruleRows = rulesIndexToString(rules, itemIndex);

        Table patternsTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), patternRows, ITEMSETS_COL_NAMES, (TypeInformation<?>[]) ITEMSETS_COL_TYPES);
        Table rulesTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), ruleRows, RULES_COL_NAMES, (TypeInformation<?>[]) RULES_COL_TYPES);
        setOutputTable(patternsTable);
        setSideOutputTables(new Table[]{rulesTable});
        return this;
    }

    /**
     * Stores the tree instance that {@code linkFrom} passes to {@code ParallelFpGrowth}.
     *
     * @param fpTree tree to use; stored as-is without validation
     */
    public void setFpTree(FpTree fpTree) {
        this.fpTree = fpTree;
    }

    /**
     * Counts the elements of a data set.
     *
     * <p>Each partition emits its own element count (0 for an empty partition); the partial
     * counts are then summed by a reduce.
     *
     * @param dataSet data set to count
     * @param <T>     element type
     * @return data set holding the summed partition counts
     */
    private static <T> DataSet<Long> count(DataSet<T> dataSet) {
        MapPartitionFunction<T, Long> partitionCounter = new MapPartitionFunction<T, Long>() {
            private static final long serialVersionUID = 5140381060401996933L;

            @Override
            public void mapPartition(Iterable<T> elements, Collector<Long> out) throws Exception {
                long seen = 0L;
                Iterator<T> cursor = elements.iterator();
                while (cursor.hasNext()) {
                    cursor.next();
                    seen++;
                }
                out.collect(seen);
            }
        };
        ReduceFunction<Long> adder = new ReduceFunction<Long>() {
            private static final long serialVersionUID = -4119867135622453149L;

            @Override
            public Long reduce(Long left, Long right) throws Exception {
                return left + right;
            }
        };
        return dataSet
            .mapPartition(partitionCounter)
            .name("count_dataset")
            .returns(Types.LONG)
            .reduce(adder);
    }

    /**
     * Derives the minimum support count from the transaction count.
     *
     * @param dataSet data set holding the transaction count
     * @param i       absolute minimum support count; used directly when it is {@code >= 0}
     * @param d       minimum support fraction; when {@code i} is negative the result is
     *                {@code floor(transactionCount * d)}
     * @return data set holding the minimum support count
     */
    private static DataSet<Long> getMinSupportCnt(DataSet<Long> dataSet, final int i, final double d) {
        MapFunction<Long, Long> resolver = new MapFunction<Long, Long>() {
            private static final long serialVersionUID = 8014662380645122004L;

            @Override
            public Long map(Long transactionCount) throws Exception {
                if (i < 0) {
                    // No absolute count configured: fall back to the fraction of all transactions.
                    double scaled = transactionCount.longValue() * d;
                    return (long) Math.floor(scaled);
                }
                return (long) i;
            }
        };
        return dataSet.map(resolver);
    }

    /**
     * Joins the names selected by {@code iArr} with a comma.
     *
     * @param strArr lookup table from item index to item name
     * @param iArr   item indexes to render, in output order
     * @return the selected names separated by ","; empty string when {@code iArr} is empty
     */
    public static String concatItems(String[] strArr, int[] iArr) {
        String[] selected = new String[iArr.length];
        int pos = 0;
        for (int itemIdx : iArr) {
            selected[pos++] = strArr[itemIdx];
        }
        return String.join(",", selected);
    }

    /**
     * Converts index-encoded frequent patterns into output rows.
     *
     * <p>Each row holds, in the order of {@link #ITEMSETS_COL_NAMES}: the comma-joined item names,
     * the pattern's count widened to {@code long}, and the number of items in the pattern.
     *
     * @param dataSet  patterns as (item indexes, count)
     * @param dataSet2 (item name, item index) pairs, broadcast as "itemIndex"
     * @return one row per pattern
     */
    private static DataSet<Row> patternsIndexToString(DataSet<Tuple2<int[], Integer>> dataSet, DataSet<Tuple2<String, Integer>> dataSet2) {
        return dataSet.map(new RichMapFunction<Tuple2<int[], Integer>, Row>() {
            private static final long serialVersionUID = -6008663329770949747L;
            transient String[] itemNames;

            @Override
            public void open(Configuration configuration) throws Exception {
                // Typed list: with a raw List the elements are Object and f0/f1 are not accessible.
                List<Tuple2<String, Integer>> itemIndex = getRuntimeContext().getBroadcastVariable("itemIndex");
                this.itemNames = new String[itemIndex.size()];
                for (Tuple2<String, Integer> entry : itemIndex) {
                    this.itemNames[entry.f1] = entry.f0;
                }
            }

            @Override
            public Row map(Tuple2<int[], Integer> pattern) throws Exception {
                String itemset = FpGrowthBatchOp.concatItems(this.itemNames, pattern.f0);
                return Row.of(itemset, pattern.f1.longValue(), (long) pattern.f0.length);
            }
        }).withBroadcastSet(dataSet2, "itemIndex");
    }

    /**
     * Converts index-encoded association rules into output rows.
     *
     * <p>Each row holds, in the order of {@link #RULES_COL_NAMES}: the rule text (items of
     * {@code f0}, {@code PrefixSpanBatchOp.RULE_SEPARATOR}, items of {@code f1}), the total number
     * of items on both sides, the three values {@code f3[0]}, {@code f3[1]}, {@code f3[2]}, and
     * {@code f2} widened to {@code long}.
     *
     * @param dataSet  rules as (left item indexes, right item indexes, count, metrics)
     * @param dataSet2 (item name, item index) pairs, broadcast as "itemIndex"
     * @return one row per rule
     */
    private static DataSet<Row> rulesIndexToString(DataSet<Tuple4<int[], int[], Integer, double[]>> dataSet, DataSet<Tuple2<String, Integer>> dataSet2) {
        return dataSet.map(new RichMapFunction<Tuple4<int[], int[], Integer, double[]>, Row>() {
            private static final long serialVersionUID = -8669206954500536812L;
            transient String[] itemNames;

            @Override
            public void open(Configuration configuration) throws Exception {
                // Typed list: with a raw List the elements are Object and f0/f1 are not accessible.
                List<Tuple2<String, Integer>> itemIndex = getRuntimeContext().getBroadcastVariable("itemIndex");
                this.itemNames = new String[itemIndex.size()];
                for (Tuple2<String, Integer> entry : itemIndex) {
                    this.itemNames[entry.f1] = entry.f0;
                }
            }

            @Override
            public Row map(Tuple4<int[], int[], Integer, double[]> rule) throws Exception {
                String ruleText = FpGrowthBatchOp.concatItems(this.itemNames, rule.f0)
                    + PrefixSpanBatchOp.RULE_SEPARATOR
                    + FpGrowthBatchOp.concatItems(this.itemNames, rule.f1);
                long itemCount = rule.f0.length + rule.f1.length;
                return Row.of(ruleText, itemCount, rule.f3[0], rule.f3[1], rule.f3[2], rule.f2.longValue());
            }
        }).withBroadcastSet(dataSet2, "itemIndex");
    }

    /**
     * Returns the operator's side output at index 0, which {@code linkFrom} fills with the
     * association rules table.
     *
     * @return side output 0
     */
    public BatchOperator<?> getSideOutputAssociationRules() {
        return getSideOutput(0);
    }

    // The decompiler-emitted synthetic bridge linkFrom(BatchOperator[]) was dropped: it has the
    // same erasure as linkFrom(BatchOperator<?>...), so it is a duplicate declaration in source
    // form and would only call itself. Array arguments bind to the varargs overload.
}
