package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.params.dataproc.HasStringOrderTypeDefaultAsRandom;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/StringIndexerUtil.class */
/**
 * Builds per-column token indices for a {@code DataSet<Row>}.
 *
 * <p>Every field of every row is turned into a token (its {@code String.valueOf} form, or
 * {@code null}) and tagged with the position of that field inside the row. The public indexing
 * methods emit one {@code (field position, token, index)} triple per distinct token of a
 * position; they differ only in the order in which indices are handed out.
 */
public class StringIndexerUtil {
    /**
     * Dispatches to the indexing strategy selected by {@code stringOrderType}.
     *
     * @param dataSet         rows whose fields are indexed position by position
     * @param stringOrderType ordering used when assigning indices
     * @param j               index given to the first token of every field position
     * @param z               if {@code true}, null fields are skipped; otherwise they are kept as a
     *                        {@code null} token
     * @return {@code (field position, token, index)} triples
     * @throws AkUnsupportedOperationException if the order type is not one of the handled constants
     */
    public static DataSet<Tuple3<Integer, String, Long>> indexTokens(DataSet<Row> dataSet, HasStringOrderTypeDefaultAsRandom.StringOrderType stringOrderType, long j, boolean z) {
        switch (stringOrderType) {
            case RANDOM:
                return indexRandom(dataSet, j, z);
            case FREQUENCY_ASC:
                return indexSortedByFreq(dataSet, j, z, true);
            case FREQUENCY_DESC:
                return indexSortedByFreq(dataSet, j, z, false);
            case ALPHABET_ASC:
                return indexSortedByAlphabet(dataSet, j, z, true);
            case ALPHABET_DESC:
                return indexSortedByAlphabet(dataSet, j, z, false);
            default:
                throw new AkUnsupportedOperationException("Unsupported order type " + stringOrderType);
        }
    }

    /**
     * Indexes the distinct tokens of each field position without imposing any ordering: indices
     * follow the order in which the grouped tokens are iterated.
     *
     * @param dataSet rows whose fields are indexed position by position
     * @param j       index given to the first token of every field position
     * @param z       if {@code true}, null fields are skipped; otherwise one {@code null} token is
     *                kept per field position
     * @return {@code (field position, token, index)} triples
     */
    public static DataSet<Tuple3<Integer, String, Long>> indexRandom(DataSet<Row> dataSet, final long j, boolean z) {
        return zipWithIndexPerColumn(distinctTokens(dataSet, z))
            .map(new MapFunction<Tuple3<Long, Integer, String>, Tuple3<Integer, String, Long>>() {
                private static final long serialVersionUID = 5009352124958484740L;

                @Override
                public Tuple3<Integer, String, Long> map(Tuple3<Long, Integer, String> indexed) throws Exception {
                    // Reorder (rank, position, token) to (position, token, index) and apply the offset.
                    return Tuple3.of(indexed.f1, indexed.f2, indexed.f0 + j);
                }
            })
            .name("assign_index");
    }

    /**
     * Indexes the distinct tokens of each field position in order of their occurrence count.
     *
     * @param dataSet rows whose fields are indexed position by position
     * @param j       index given to the first token of every field position
     * @param z       if {@code true}, null fields are skipped; otherwise they are counted as a
     *                {@code null} token
     * @param z2      {@code true} to sort counts ascending, {@code false} for descending
     * @return {@code (field position, token, index)} triples
     */
    public static DataSet<Tuple3<Integer, String, Long>> indexSortedByFreq(DataSet<Row> dataSet, final long j, boolean z, boolean z2) {
        Order countOrder = z2 ? Order.ASCENDING : Order.DESCENDING;
        return countTokens(dataSet, z)
            .groupBy(0)
            .sortGroup(2, countOrder)
            .reduceGroup(new GroupReduceFunction<Tuple3<Integer, String, Long>, Tuple3<Integer, String, Long>>() {
                private static final long serialVersionUID = 3454314323952925197L;

                @Override
                public void reduce(Iterable<Tuple3<Integer, String, Long>> counted, Collector<Tuple3<Integer, String, Long>> out) throws Exception {
                    long nextIndex = j;
                    for (Tuple3<Integer, String, Long> entry : counted) {
                        // The count in f2 is replaced by the assigned index.
                        out.collect(Tuple3.of(entry.f0, entry.f1, nextIndex));
                        nextIndex++;
                    }
                }
            });
    }

    /**
     * Indexes the distinct tokens of each field position. This builds exactly the same pipeline as
     * {@link #indexRandom(DataSet, long, boolean)} and is kept as a separate entry point.
     *
     * @param dataSet rows whose fields are indexed position by position
     * @param j       index given to the first token of every field position
     * @param z       if {@code true}, null fields are skipped; otherwise one {@code null} token is
     *                kept per field position
     * @return {@code (field position, token, index)} triples
     */
    public static DataSet<Tuple3<Integer, String, Long>> distinct(DataSet<Row> dataSet, final long j, boolean z) {
        return indexRandom(dataSet, j, z);
    }

    /**
     * Counts how often each token occurs at each field position.
     *
     * @param dataSet rows whose fields are counted position by position
     * @param z       if {@code true}, null fields are skipped; otherwise they are counted under a
     *                {@code null} token
     * @return {@code (field position, token, count)} triples, one per distinct token of a position
     */
    public static DataSet<Tuple3<Integer, String, Long>> countTokens(DataSet<Row> dataSet, boolean z) {
        return flattenTokens(dataSet, z)
            .map(new MapFunction<Tuple2<Integer, String>, Tuple3<Integer, String, Long>>() {
                private static final long serialVersionUID = -3620557146524640793L;

                @Override
                public Tuple3<Integer, String, Long> map(Tuple2<Integer, String> token) throws Exception {
                    return Tuple3.of(token.f0, token.f1, 1L);
                }
            })
            .groupBy(0)
            .reduceGroup(new GroupReduceFunction<Tuple3<Integer, String, Long>, Tuple3<Integer, String, Long>>() {
                @Override
                public void reduce(Iterable<Tuple3<Integer, String, Long>> values, Collector<Tuple3<Integer, String, Long>> out) throws Exception {
                    int position = -1;
                    // Null tokens are tallied separately from the non-null tokens in the map.
                    long nullCount = 0;
                    Map<String, Long> counts = new HashMap<>();
                    for (Tuple3<Integer, String, Long> value : values) {
                        if (position == -1) {
                            position = value.f0;
                        }
                        if (value.f1 == null) {
                            nullCount += value.f2;
                        } else {
                            counts.merge(value.f1, value.f2, Long::sum);
                        }
                    }
                    if (nullCount != 0) {
                        out.collect(Tuple3.of(position, (String) null, nullCount));
                    }
                    for (Map.Entry<String, Long> entry : counts.entrySet()) {
                        out.collect(Tuple3.of(position, entry.getKey(), entry.getValue()));
                    }
                }
            })
            .name("count_tokens");
    }

    /**
     * Indexes the distinct tokens of each field position in natural {@code String} order (or its
     * reverse). A {@code null} token, if present, always receives the first index.
     *
     * @param dataSet rows whose fields are indexed position by position
     * @param j       index given to the first token of every field position
     * @param z       if {@code true}, null fields are skipped; otherwise one {@code null} token is
     *                kept per field position
     * @param z2      {@code true} for natural order, {@code false} for reverse order
     * @return {@code (field position, token, index)} triples
     */
    public static DataSet<Tuple3<Integer, String, Long>> indexSortedByAlphabet(DataSet<Row> dataSet, final long j, boolean z, final boolean z2) {
        // Use the null-safe de-duplication: grouping on the token field itself is not possible
        // when null tokens are kept, because Flink does not accept null key fields.
        return distinctTokens(dataSet, z)
            .groupBy(0)
            .reduceGroup(new RichGroupReduceFunction<Tuple2<Integer, String>, Tuple3<Integer, String, Long>>() {
                private static final long serialVersionUID = -5673388400144888098L;

                @Override
                public void reduce(Iterable<Tuple2<Integer, String>> values, Collector<Tuple3<Integer, String, Long>> out) throws Exception {
                    int position = -1;
                    ArrayList<String> tokens = new ArrayList<>();
                    for (Tuple2<Integer, String> value : values) {
                        position = value.f0;
                        tokens.add(value.f1);
                    }
                    Comparator<String> order = z2 ? Comparator.<String>naturalOrder() : Comparator.<String>reverseOrder();
                    tokens.sort(Comparator.nullsFirst(order));
                    for (int rank = 0; rank < tokens.size(); rank++) {
                        out.collect(Tuple3.of(position, tokens.get(rank), j + rank));
                    }
                }
            })
            .name("assign_index");
    }

    /**
     * Reduces the flattened tokens to one occurrence per {@code (field position, token)} pair.
     *
     * @param dataSet rows to flatten
     * @param z       if {@code true}, null fields are skipped, so the token can be used as a grouping
     *                key; otherwise tokens are de-duplicated inside a per-position group
     * @return distinct {@code (field position, token)} pairs
     */
    private static DataSet<Tuple2<Integer, String>> distinctTokens(DataSet<Row> dataSet, boolean z) {
        DataSet<Tuple2<Integer, String>> tokens = flattenTokens(dataSet, z);
        if (z) {
            return tokens
                .groupBy(0, 1)
                .reduce(new ReduceFunction<Tuple2<Integer, String>>() {
                    private static final long serialVersionUID = 3246078624056103227L;

                    @Override
                    public Tuple2<Integer, String> reduce(Tuple2<Integer, String> first, Tuple2<Integer, String> second) throws Exception {
                        // Both operands share the same key, so either one represents the group.
                        return first;
                    }
                })
                .name("distinct_tokens");
        }
        return tokens
            .groupBy(0)
            .reduceGroup(new GroupReduceFunction<Tuple2<Integer, String>, Tuple2<Integer, String>>() {
                @Override
                public void reduce(Iterable<Tuple2<Integer, String>> values, Collector<Tuple2<Integer, String>> out) throws Exception {
                    boolean nullEmitted = false;
                    HashSet<String> seen = new HashSet<>();
                    for (Tuple2<Integer, String> value : values) {
                        if (value.f1 == null) {
                            if (!nullEmitted) {
                                nullEmitted = true;
                                out.collect(value);
                            }
                        } else if (seen.add(value.f1)) {
                            // add() returns true only the first time a token is seen.
                            out.collect(value);
                        }
                    }
                }
            })
            .name("distinct_tokens");
    }

    /**
     * Emits one {@code (field position, token)} pair per field of every row.
     *
     * @param dataSet rows to flatten
     * @param z       if {@code true}, null fields produce no output; otherwise they produce a pair
     *                with a {@code null} token
     * @return the flattened pairs
     */
    private static DataSet<Tuple2<Integer, String>> flattenTokens(DataSet<Row> dataSet, final boolean z) {
        return dataSet
            .flatMap(new FlatMapFunction<Row, Tuple2<Integer, String>>() {
                private static final long serialVersionUID = 4865017597670627434L;

                @Override
                public void flatMap(Row row, Collector<Tuple2<Integer, String>> out) throws Exception {
                    for (int position = 0; position < row.getArity(); position++) {
                        Object field = row.getField(position);
                        if (field != null) {
                            out.collect(Tuple2.of(position, String.valueOf(field)));
                        } else if (!z) {
                            out.collect(Tuple2.of(position, (String) null));
                        }
                    }
                }
            })
            .name("flatten_tokens");
    }

    /**
     * Numbers the elements of each field position from zero, in the order the group is iterated.
     *
     * @param dataSet {@code (field position, token)} pairs
     * @return {@code (rank within position, field position, token)} triples
     */
    public static DataSet<Tuple3<Long, Integer, String>> zipWithIndexPerColumn(DataSet<Tuple2<Integer, String>> dataSet) {
        return dataSet
            .groupBy(0)
            .reduceGroup(new GroupReduceFunction<Tuple2<Integer, String>, Tuple3<Long, Integer, String>>() {
                private static final long serialVersionUID = 1297859189970595767L;

                @Override
                public void reduce(Iterable<Tuple2<Integer, String>> values, Collector<Tuple3<Long, Integer, String>> out) throws Exception {
                    long rank = 0;
                    for (Tuple2<Integer, String> value : values) {
                        out.collect(Tuple3.of(rank, value.f0, value.f1));
                        rank++;
                    }
                }
            });
    }
}
