package com.alibaba.alink.operator.batch.associationrule;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.associationrule.ParallelPrefixSpan;
import com.alibaba.alink.operator.common.associationrule.SequenceRule;
import com.alibaba.alink.params.associationrule.PrefixSpanParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.operators.AggregateOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Batch operator that mines frequent sequential patterns with PrefixSpan and derives sequence rules
 * from them.
 *
 * <p>Each input row carries one sequence in the {@code itemsCol} column. Elements of a sequence are
 * separated by {@value #ELEMENT_SEPARATOR} and items inside an element by {@value #ITEM_SEPARATOR},
 * e.g. {@code "a,b;c;a,c"}. Null or blank cells are treated as empty sequences.
 *
 * <p>The main output holds the frequent patterns (columns {@code itemset, supportcount, itemcount}).
 * Side output 0 holds the rules (columns
 * {@code rule, chain_length, support, confidence, transaction_count}); in each rule the antecedent
 * and consequent are joined by {@value #RULE_SEPARATOR}.
 */
@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "itemsCol", allowedTypeCollections = {TypeCollections.STRING_TYPES})
@NameCn("PrefixSpan")
@NameEn("PrefixSpan")
public final class PrefixSpanBatchOp extends BatchOperator<PrefixSpanBatchOp> implements PrefixSpanParams<PrefixSpanBatchOp> {
    public static final String ITEM_SEPARATOR = ",";
    public static final String ELEMENT_SEPARATOR = ";";
    public static final String RULE_SEPARATOR = "=>";
    private static final long serialVersionUID = 359216485925896796L;
    private static final Logger LOG = LoggerFactory.getLogger(PrefixSpanBatchOp.class);
    private static final String[] ITEMSETS_COL_NAMES = {"itemset", "supportcount", "itemcount"};
    private static final String[] RULES_COL_NAMES = {"rule", "chain_length", "support", "confidence", "transaction_count"};
    private static final TypeInformation<?>[] ITEMSETS_COL_TYPES = {Types.STRING, Types.LONG, Types.LONG};
    private static final TypeInformation<?>[] RULES_COL_TYPES = {Types.STRING, Types.LONG, Types.DOUBLE, Types.DOUBLE, Types.LONG};

    public PrefixSpanBatchOp() {
        this(new Params());
    }

    public PrefixSpanBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the PrefixSpan pipeline on the first input and sets the pattern table as the main output
     * and the rule table as side output 0.
     *
     * @param inputs the input operators; only the first one is used
     * @return this operator
     */
    @Override
    public PrefixSpanBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        final double minSupportPercent = getMinSupportPercent();
        final int minSupportCount = getMinSupportCount();
        final int maxPatternLength = getMaxPatternLength();
        final double minConfidence = getMinConfidence();
        final int itemsColIdx = TableUtil.findColIndexWithAssertAndHint(in.getSchema(), getItemsCol());

        DataSet<Long> sequenceCount = count(in.getDataSet());
        DataSet<Long> minSupportCnt = getMinSupportCnt(sequenceCount, minSupportCount, minSupportPercent);

        DataSet<List<List<String>>> sequences = splitSequences(in.getDataSet(), itemsColIdx);
        DataSet<Tuple2<String, Integer>> itemCounts = countItems(sequences);
        DataSet<Tuple2<String, Integer>> itemIndex = indexQualifiedItems(itemCounts, minSupportCnt);
        DataSet<int[]> encodedSequences = encodeSequences(sequences, itemIndex);

        // (itemId, supportCount) for every qualified item; the join drops unqualified items.
        DataSet<Tuple2<Integer, Integer>> itemIdCounts = itemCounts
            .join(itemIndex)
            .where(0)
            .equalTo(0)
            .projectSecond(1)
            .projectFirst(1);

        DataSet<Tuple2<int[], Integer>> patterns =
            new ParallelPrefixSpan(encodedSequences, minSupportCnt, itemIdCounts, maxPatternLength).run();
        DataSet<Tuple4<int[], int[], Integer, double[]>> rules =
            SequenceRule.extractSequenceRules(patterns, sequenceCount, minConfidence);

        Table patternsTable = DataSetConversionUtil.toTable(getMLEnvironmentId(),
            patternsIndexToString(patterns, itemIndex), ITEMSETS_COL_NAMES, ITEMSETS_COL_TYPES);
        Table rulesTable = DataSetConversionUtil.toTable(getMLEnvironmentId(),
            rulesIndexToString(rules, itemIndex), RULES_COL_NAMES, RULES_COL_TYPES);
        setOutputTable(patternsTable);
        setSideOutputTables(new Table[] {rulesTable});
        return this;
    }

    /** Parses each row's sequence string into a list of elements, each element being a list of items. */
    private static DataSet<List<List<String>>> splitSequences(DataSet<Row> data, final int itemsColIdx) {
        return data
            .map(new MapFunction<Row, List<List<String>>>() {
                private static final long serialVersionUID = -61898178381287582L;

                @Override
                public List<List<String>> map(Row row) {
                    String sequence = (String) row.getField(itemsColIdx);
                    if (StringUtils.isNullOrWhitespaceOnly(sequence)) {
                        return new ArrayList<>();
                    }
                    String[] elements = sequence.split(ELEMENT_SEPARATOR);
                    List<List<String>> parsed = new ArrayList<>(elements.length);
                    for (String element : elements) {
                        // Only the element string is trimmed; individual items are kept as split.
                        parsed.add(Arrays.asList(element.trim().split(ITEM_SEPARATOR)));
                    }
                    return parsed;
                }
            })
            .name("split_sequences");
    }

    /**
     * Computes, for every item, the number of sequences that contain it.
     *
     * <p>Each distinct item is emitted at most once per sequence. Its count is therefore comparable
     * with {@code minSupportCnt}, which is derived from the number of sequences, not the number of
     * occurrences.
     */
    private static DataSet<Tuple2<String, Integer>> countItems(DataSet<List<List<String>>> sequences) {
        return sequences
            .flatMap(new FlatMapFunction<List<List<String>>, Tuple2<String, Integer>>() {
                private static final long serialVersionUID = -8797680399192984088L;

                @Override
                public void flatMap(List<List<String>> sequence, Collector<Tuple2<String, Integer>> out) {
                    Set<String> distinctItems = new HashSet<>();
                    for (List<String> element : sequence) {
                        distinctItems.addAll(element);
                    }
                    for (String item : distinctItems) {
                        out.collect(Tuple2.of(item, 1));
                    }
                }
            })
            .groupBy(0)
            .aggregate(Aggregations.SUM, 1);
    }

    /**
     * Keeps items whose count is at least the broadcast minimum support and assigns them ids 1..n.
     *
     * <p>Ids are ordered by descending count, ties broken by ascending item name. Id 0 is never
     * assigned because it is the element delimiter in the int-array sequence encoding.
     */
    private static DataSet<Tuple2<String, Integer>> indexQualifiedItems(
        DataSet<Tuple2<String, Integer>> itemCounts, DataSet<Long> minSupportCnt) {
        DataSet<Tuple2<String, Integer>> qualifiedItems = itemCounts
            .filter(new RichFilterFunction<Tuple2<String, Integer>>() {
                private static final long serialVersionUID = -6389675086514287621L;
                private transient long minSupport;

                @Override
                public void open(Configuration parameters) throws Exception {
                    List<Long> bc = getRuntimeContext().getBroadcastVariable("minSupportCnt");
                    minSupport = bc.get(0);
                    LOG.info("minSupportCnt {}", minSupport);
                }

                @Override
                public boolean filter(Tuple2<String, Integer> itemCount) {
                    return itemCount.f1 >= minSupport;
                }
            })
            .withBroadcastSet(minSupportCnt, "minSupportCnt")
            .name("get_qualified_items");

        // A single-element source runs the indexing exactly once over the whole broadcast item list.
        return itemCounts.getExecutionEnvironment()
            .fromElements(0)
            .flatMap(new RichFlatMapFunction<Integer, Tuple2<String, Integer>>() {
                private static final long serialVersionUID = 6034744231319348373L;

                @Override
                public void flatMap(Integer ignored, Collector<Tuple2<String, Integer>> out) {
                    List<Tuple2<String, Integer>> qualified = getRuntimeContext().getBroadcastVariable("qualifiedItems");
                    // Sort a copy: the broadcast list is shared between parallel tasks on the same worker.
                    List<Tuple2<String, Integer>> ordered = new ArrayList<>(qualified);
                    ordered.sort((a, b) -> a.f1.equals(b.f1) ? a.f0.compareTo(b.f0) : Integer.compare(b.f1, a.f1));
                    for (int i = 0; i < ordered.size(); i++) {
                        out.collect(Tuple2.of(ordered.get(i).f0, i + 1));
                    }
                }
            })
            .withBroadcastSet(qualifiedItems, "qualifiedItems");
    }

    /**
     * Encodes each sequence as an int array of item ids.
     *
     * <p>The array starts with the delimiter 0, and a 0 follows every element that keeps at least one
     * qualified item. Items that are not in {@code itemIndex} are dropped.
     */
    private static DataSet<int[]> encodeSequences(DataSet<List<List<String>>> sequences,
                                                  DataSet<Tuple2<String, Integer>> itemIndex) {
        return sequences
            .map(new RichMapFunction<List<List<String>>, int[]>() {
                private static final long serialVersionUID = -9049976553096196637L;
                private transient Map<String, Integer> tokenToId;

                @Override
                public void open(Configuration parameters) throws Exception {
                    List<Tuple2<String, Integer>> index = getRuntimeContext().getBroadcastVariable("itemIndex");
                    tokenToId = new HashMap<>();
                    for (Tuple2<String, Integer> entry : index) {
                        tokenToId.put(entry.f0, entry.f1);
                    }
                }

                @Override
                public int[] map(List<List<String>> sequence) {
                    List<Integer> ids = new ArrayList<>();
                    ids.add(0);
                    for (List<String> element : sequence) {
                        int numItems = 0;
                        for (String item : element) {
                            Integer id = tokenToId.get(item);
                            if (id != null) {
                                ids.add(id);
                                numItems++;
                            }
                        }
                        // Elements left empty after filtering produce no delimiter.
                        if (numItems > 0) {
                            ids.add(0);
                        }
                    }
                    return ids.stream().mapToInt(Integer::intValue).toArray();
                }
            })
            .withBroadcastSet(itemIndex, "itemIndex")
            .name("map_seq_to_int_array");
    }

    /** Counts the records of {@code dataSet}: one partial count per partition, then summed. */
    private static <T> DataSet<Long> count(DataSet<T> dataSet) {
        return dataSet
            .mapPartition(new MapPartitionFunction<T, Long>() {
                private static final long serialVersionUID = -7330867302047392129L;

                @Override
                public void mapPartition(Iterable<T> values, Collector<Long> out) {
                    long n = 0L;
                    for (T ignored : values) {
                        n++;
                    }
                    out.collect(n);
                }
            })
            .name("count_dataset")
            .returns(Types.LONG)
            .reduce(new ReduceFunction<Long>() {
                private static final long serialVersionUID = 3246426894892395344L;

                @Override
                public Long reduce(Long a, Long b) {
                    return a + b;
                }
            });
    }

    /**
     * Resolves the absolute minimum support.
     *
     * <p>If {@code minSupportCount} is non-negative it is used as is. Otherwise the result is
     * {@code floor(totalCount * minSupportPercent)}.
     */
    private static DataSet<Long> getMinSupportCnt(DataSet<Long> totalCount, final int minSupportCount,
                                                  final double minSupportPercent) {
        return totalCount.map(new MapFunction<Long, Long>() {
            private static final long serialVersionUID = 3819221532061581540L;

            @Override
            public Long map(Long total) {
                if (minSupportCount >= 0) {
                    return (long) minSupportCount;
                }
                return (long) Math.floor(total * minSupportPercent);
            }
        });
    }

    /**
     * Decodes an int-array sequence back to its string form. Internal helper.
     *
     * <p>The first and last entries are skipped; they are assumed to be the framing delimiter 0. Inner
     * zeros separate elements, and other values index into {@code itemNames}.
     *
     * @param sequence  encoded sequence
     * @param itemNames lookup from item id to item name
     * @return (encoded string, number of items, number of elements)
     */
    public static Tuple3<String, Long, Long> encodeSequence(int[] sequence, String[] itemNames) {
        StringBuilder sb = new StringBuilder();
        long numItems = 0L;
        long numElements = 1L;
        int itemsInElement = 0;
        for (int i = 1; i < sequence.length - 1; i++) {
            int id = sequence[i];
            if (id == 0) {
                sb.append(ELEMENT_SEPARATOR);
                numElements++;
                itemsInElement = 0;
            } else {
                if (itemsInElement > 0) {
                    sb.append(ITEM_SEPARATOR);
                }
                sb.append(itemNames[id]);
                itemsInElement++;
                numItems++;
            }
        }
        return Tuple3.of(sb.toString(), numItems, numElements);
    }

    /**
     * Builds the id-to-name lookup from the broadcast (name, id) item index.
     *
     * <p>Ids are 1..n, so slot 0 (the delimiter id) stays null.
     */
    private static String[] buildItemNames(List<Tuple2<String, Integer>> itemIndex) {
        String[] names = new String[itemIndex.size() + 1];
        for (Tuple2<String, Integer> entry : itemIndex) {
            names[entry.f1] = entry.f0;
        }
        return names;
    }

    /** Converts (encoded pattern, support) pairs to rows {@code (itemset, supportcount, itemcount)}. */
    private static DataSet<Row> patternsIndexToString(DataSet<Tuple2<int[], Integer>> patterns,
                                                      DataSet<Tuple2<String, Integer>> itemIndex) {
        return patterns
            .map(new RichMapFunction<Tuple2<int[], Integer>, Row>() {
                private static final long serialVersionUID = -688658684297137986L;
                private transient String[] itemNames;

                @Override
                public void open(Configuration parameters) throws Exception {
                    List<Tuple2<String, Integer>> index = getRuntimeContext().getBroadcastVariable("itemIndex");
                    itemNames = buildItemNames(index);
                }

                @Override
                public Row map(Tuple2<int[], Integer> pattern) {
                    Tuple3<String, Long, Long> encoded = encodeSequence(pattern.f0, itemNames);
                    return Row.of(encoded.f0, pattern.f1.longValue(), encoded.f1);
                }
            })
            .withBroadcastSet(itemIndex, "itemIndex")
            .name("patternsIndexToString");
    }

    /**
     * Converts rule tuples to rows {@code (rule, chain_length, support, confidence, transaction_count)}.
     *
     * <p>{@code chain_length} is the total number of elements in the antecedent and the consequent.
     */
    private static DataSet<Row> rulesIndexToString(DataSet<Tuple4<int[], int[], Integer, double[]>> rules,
                                                   DataSet<Tuple2<String, Integer>> itemIndex) {
        return rules
            .map(new RichMapFunction<Tuple4<int[], int[], Integer, double[]>, Row>() {
                private static final long serialVersionUID = 7821061547668068480L;
                private transient String[] itemNames;

                @Override
                public void open(Configuration parameters) throws Exception {
                    List<Tuple2<String, Integer>> index = getRuntimeContext().getBroadcastVariable("itemIndex");
                    itemNames = buildItemNames(index);
                }

                @Override
                public Row map(Tuple4<int[], int[], Integer, double[]> rule) {
                    Tuple3<String, Long, Long> antecedent = encodeSequence(rule.f0, itemNames);
                    Tuple3<String, Long, Long> consequent = encodeSequence(rule.f1, itemNames);
                    return Row.of(
                        antecedent.f0 + RULE_SEPARATOR + consequent.f0,
                        antecedent.f2 + consequent.f2,
                        rule.f3[0],
                        rule.f3[1],
                        rule.f2.longValue());
                }
            })
            .withBroadcastSet(itemIndex, "itemIndex")
            .name("rulesIndexToString");
    }

    /** Returns the sequence-rule table (side output 0) as a batch operator. */
    public BatchOperator<?> getSideOutputAssociationRules() {
        return getSideOutput(0);
    }
}
