package com.alibaba.alink.operator.common.nlp;

import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/nlp/WordCountUtil.class */
public class WordCountUtil {
    public static final String WORD_COL_NAME = "word";
    public static final String COUNT_COL_NAME = "cnt";
    public static final String INDEX_COL_NAME = "idx";
    public static final int BOUND_SIZE = 10000;
    private static final Logger LOG = LoggerFactory.getLogger(WordCountUtil.class);

    /* loaded from: input_file:com/alibaba/alink/operator/common/nlp/WordCountUtil$GenContentMapping.class */
    public static final class GenContentMapping extends RichMapFunction<Row, Row> {
        private static final long serialVersionUID = -2148000211178502657L;
        private final boolean isWord;
        private final String wordDelimiter;
        private final int transColSize;
        private final int keepColSize;
        private final int outputColSize;
        private final int[] colIdxs;
        private final int[] appendColIdxs;
        private Map<String, Long> vocMap;

        public GenContentMapping(int[] iArr, int[] iArr2, boolean z, String str) {
            this.colIdxs = iArr;
            this.appendColIdxs = iArr2;
            this.isWord = z;
            this.wordDelimiter = str;
            this.transColSize = this.colIdxs.length;
            this.keepColSize = this.appendColIdxs == null ? 0 : this.appendColIdxs.length;
            this.outputColSize = this.transColSize + this.keepColSize;
        }

        public void open(Configuration configuration) throws Exception {
            List<Row> broadcastVariable = getRuntimeContext().getBroadcastVariable("vocabulary");
            this.vocMap = new HashMap();
            for (Row row : broadcastVariable) {
                this.vocMap.put((String) row.getField(0), (Long) row.getField(1));
            }
        }

        public Row map(Row row) throws Exception {
            Row row2 = new Row(this.outputColSize);
            for (int i = 0; i < this.transColSize; i++) {
                if (this.isWord) {
                    row2.setField(i, this.vocMap.getOrDefault((String) row.getField(this.colIdxs[i]), null));
                } else {
                    String[] split = ((String) row.getField(this.colIdxs[i])).split(this.wordDelimiter);
                    StringBuilder sb = new StringBuilder();
                    int i2 = 0;
                    for (String str : split) {
                        if (this.vocMap.containsKey(str)) {
                            if (i2 > 0) {
                                sb.append(",");
                            }
                            sb.append(this.vocMap.get(str).toString());
                            i2++;
                        }
                    }
                    row2.setField(i, sb.toString());
                }
            }
            for (int i3 = 0; i3 < this.keepColSize; i3++) {
                row2.setField(this.transColSize + i3, row.getField(this.appendColIdxs[i3]));
            }
            return row2;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/nlp/WordCountUtil$RandomIndexMapper.class */
    public static class RandomIndexMapper implements MapFunction<Tuple2<Long, Row>, Row> {
        private static final long serialVersionUID = -7478698242309473099L;
        private final long startIndex;

        public RandomIndexMapper(long j) {
            this.startIndex = j;
        }

        public Row map(Tuple2<Long, Row> tuple2) throws Exception {
            return RowUtil.merge((Row) tuple2.f1, Long.valueOf(((Long) tuple2.f0).longValue() + this.startIndex));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/nlp/WordCountUtil$WordSpliter.class */
    public static class WordSpliter implements FlatMapFunction<Row, String[]> {
        private static final long serialVersionUID = -699577713738103461L;
        private final String wordDelimiter;

        public WordSpliter(String str) {
            this.wordDelimiter = str;
        }

        public void flatMap(Row row, Collector<String[]> collector) throws Exception {
            if (null != row.getField(0)) {
                collector.collect(((String) row.getField(0)).split(this.wordDelimiter));
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<String[]>) collector);
        }
    }

    /**
     * Splits the documents in column {@code str} via {@link #splitDoc} and then sums the emitted
     * {@code cnt} values per {@code word}.
     *
     * @param batchOperator input operator
     * @param str           name of the document column
     * @param str2          delimiter handed to the splitting function
     * @return operator with columns {@code word} and {@code cnt}
     */
    public static BatchOperator<?> splitDocAndCount(BatchOperator<?> batchOperator, String str, String str2) {
        BatchOperator<?> perDocCounts = splitDoc(batchOperator, str, str2);
        return count(perDocCounts, WORD_COL_NAME, COUNT_COL_NAME);
    }

    /**
     * Applies the {@code DocWordSplitCount} table function to column {@code str}, producing the
     * output columns {@code word} and {@code cnt}.
     *
     * @param batchOperator input operator
     * @param str           name of the document column
     * @param str2          delimiter handed to {@code DocWordSplitCount}
     */
    public static BatchOperator<?> splitDoc(BatchOperator<?> batchOperator, String str, String str2) {
        String[] outputCols = {WORD_COL_NAME, COUNT_COL_NAME};
        String[] emptyCols = new String[0];
        return batchOperator.udtf(str, outputCols, new DocWordSplitCount(str2), emptyCols);
    }

    /**
     * Counts the occurrences of each distinct value in column {@code str}.
     *
     * @param batchOperator input operator
     * @param str           name of the word column
     * @return operator with columns {@code word} and {@code cnt}
     */
    public static BatchOperator<?> count(BatchOperator<?> batchOperator, String str) {
        // A null count column selects the COUNT(...) aggregation in the three-argument overload.
        final String noCountCol = null;
        return count(batchOperator, str, noCountCol);
    }

    /**
     * Groups by column {@code str} and aggregates per group.
     *
     * @param batchOperator input operator
     * @param str           name of the word column (group key, emitted as {@code word})
     * @param str2          column whose values are summed, or {@code null} to count rows instead
     * @return operator with columns {@code word} and {@code cnt}
     */
    public static BatchOperator<?> count(BatchOperator<?> batchOperator, String str, String str2) {
        final String aggregate;
        if (str2 == null) {
            aggregate = "COUNT(" + str + ")";
        } else {
            aggregate = "SUM(" + str2 + ")";
        }
        return batchOperator.groupBy(str, str + " AS word, " + aggregate + " AS " + COUNT_COL_NAME);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Appends an {@code idx} column of type LONG to the operator, filled by
     * {@link #randomIndexVocab(DataSet, long)}.
     *
     * @param batchOperator vocabulary operator
     * @param j             offset added to every generated index
     * @return a new table-source operator in the same ML environment
     */
    public static BatchOperator<?> randomIndexVocab(BatchOperator<?> batchOperator, long j) {
        TableSchema schema = batchOperator.getSchema();
        String[] extendedNames = (String[]) ArrayUtils.add(schema.getFieldNames(), INDEX_COL_NAME);
        TypeInformation<?>[] extendedTypes = (TypeInformation<?>[]) ArrayUtils.add(schema.getFieldTypes(), Types.LONG);
        DataSet<Row> indexedRows = randomIndexVocab(batchOperator.getDataSet(), j);
        TableSourceBatchOp source = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(batchOperator.getMLEnvironmentId(), indexedRows, extendedNames, extendedTypes));
        return (BatchOperator) source.setMLEnvironmentId(batchOperator.getMLEnvironmentId());
    }

    /**
     * Numbers the rows with {@code DataSetUtils.zipWithIndex} and appends that number plus
     * {@code j} to each row as an extra trailing field.
     *
     * @param dataSet vocabulary rows
     * @param j       offset added to every generated index
     */
    public static DataSet<Row> randomIndexVocab(DataSet<Row> dataSet, long j) {
        DataSet<Tuple2<Long, Row>> numbered = DataSetUtils.zipWithIndex(dataSet);
        return numbered.map(new RandomIndexMapper(j));
    }

    /**
     * Convenience overload of {@link #sortedIndexVocab(DataSet, long, boolean)} with grouping
     * disabled.
     *
     * @param dataSet vocabulary rows
     * @param j       first index to assign
     */
    public static Tuple3<DataSet<Row>, DataSet<Long[]>, DataSet<long[]>> sortedIndexVocab(DataSet<Row> dataSet, long j) {
        final boolean groupByType = false;
        return sortedIndexVocab(dataSet, j, groupByType);
    }

    /**
     * Assigns consecutive indices to vocabulary rows after a distributed sort on field 2 and
     * derives, from the running sum of field 2, a table of index bounds with 10001 slots per
     * group.
     *
     * <p>Input rows are read as follows: fields 0 and 1 are copied to the output, field 2 is cast
     * to {@code Double} (presumably a count or weight; confirm against callers), and field 3 is
     * used as the group key when {@code z} is {@code true}.
     *
     * @param dataSet vocabulary rows
     * @param j       value added to every assigned index (index of the first row)
     * @param z       whether rows are indexed separately per distinct value of field 3
     * @return a tuple of
     *     <ul>
     *       <li>f0: rows {@code (field 0, field 1, assigned index)};</li>
     *       <li>f1: a single {@code Long[]} of length {@code 10001 * groups} holding the index
     *           recorded for each bound slot, gaps filled with the preceding slot's value;</li>
     *       <li>f2: a single {@code long[]} with the first index of every group, or {@code null}
     *           when {@code z} is {@code false}.</li>
     *     </ul>
     */
    public static Tuple3<DataSet<Row>, DataSet<Long[]>, DataSet<long[]>> sortedIndexVocab(DataSet<Row> dataSet, final long j, final boolean z) {
        // With grouping enabled: collect each distinct value of field 3, then number the values
        // 0..n-1 in a single non-grouped reduce. The result is broadcast as "w2vGroupTypes".
        GroupReduceOperator groupReduceOperator = null;
        if (z) {
            groupReduceOperator = dataSet.groupBy(new int[]{3}).reduceGroup(new GroupReduceFunction<Row, Object>() {
                private static final long serialVersionUID = -4853809045012459178L;

                public void reduce(Iterable<Row> iterable, Collector<Object> collector) throws Exception {
                    Object obj = null;
                    Iterator<Row> it = iterable.iterator();
                    while (it.hasNext()) {
                        obj = it.next().getField(3);
                    }
                    collector.collect(obj);
                }
            }).reduceGroup(new GroupReduceFunction<Object, Tuple2<Object, Integer>>() {
                private static final long serialVersionUID = -58559875101303985L;

                public void reduce(Iterable<Object> iterable, Collector<Tuple2<Object, Integer>> collector) throws Exception {
                    int i = 0;
                    Iterator<Object> it = iterable.iterator();
                    while (it.hasNext()) {
                        int i2 = i;
                        i++;
                        collector.collect(Tuple2.of(it.next(), Integer.valueOf(i2)));
                    }
                }
            });
        }
        // SortUtils.pSort tags every row with an Integer; that tag is used directly as the target
        // partition number.
        PartitionOperator partitionCustom = ((DataSet) SortUtils.pSort(dataSet, 2).f0).partitionCustom(new Partitioner<Integer>() {
            private static final long serialVersionUID = -533045561688945931L;

            public int partition(Integer num, int i) {
                return num.intValue();
            }
        }, 0);
        // Per-partition row counts, one counter per group ("w2vInstVocabSize").
        MapPartitionOperator mapPartition = partitionCustom.mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Integer, long[]>>() {
            private static final long serialVersionUID = -7621048604300571469L;
            int instId;
            final Map<Object, Integer> typeMap = new HashMap();

            public void open(Configuration configuration) throws Exception {
                this.instId = getRuntimeContext().getIndexOfThisSubtask();
                if (z) {
                    for (Tuple2 tuple2 : getRuntimeContext().getBroadcastVariable("w2vGroupTypes")) {
                        this.typeMap.put(tuple2.f0, tuple2.f1);
                    }
                }
            }

            public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Integer, long[]>> collector) {
                long[] jArr = z ? new long[this.typeMap.keySet().size()] : new long[1];
                for (Tuple2<Integer, Row> tuple2 : iterable) {
                    if (z) {
                        long[] jArr2 = jArr;
                        int intValue = this.typeMap.get(((Row) tuple2.f1).getField(3)).intValue();
                        jArr2[intValue] = jArr2[intValue] + 1;
                    } else {
                        long[] jArr3 = jArr;
                        jArr3[0] = jArr3[0] + 1;
                    }
                }
                collector.collect(Tuple2.of(Integer.valueOf(this.instId), jArr));
            }
        });
        if (z) {
            mapPartition = mapPartition.withBroadcastSet(groupReduceOperator, "w2vGroupTypes");
        }
        // Per-partition sums of field 2, one sum per group ("w2vWeightInstSum").
        MapPartitionOperator mapPartition2 = partitionCustom.mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Integer, double[]>>() {
            private static final long serialVersionUID = 3802056745076337899L;
            int instId;
            final Map<Object, Integer> typeMap = new HashMap();

            public void open(Configuration configuration) throws Exception {
                this.instId = getRuntimeContext().getIndexOfThisSubtask();
                if (z) {
                    for (Tuple2 tuple2 : getRuntimeContext().getBroadcastVariable("w2vGroupTypes")) {
                        this.typeMap.put(tuple2.f0, tuple2.f1);
                    }
                }
            }

            public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Integer, double[]>> collector) {
                double[] dArr = z ? new double[this.typeMap.keySet().size()] : new double[1];
                for (Tuple2<Integer, Row> tuple2 : iterable) {
                    if (z) {
                        double[] dArr2 = dArr;
                        int intValue = this.typeMap.get(((Row) tuple2.f1).getField(3)).intValue();
                        dArr2[intValue] = dArr2[intValue] + ((Double) ((Row) tuple2.f1).getField(2)).doubleValue();
                    } else {
                        double[] dArr3 = dArr;
                        dArr3[0] = dArr3[0] + ((Double) ((Row) tuple2.f1).getField(2)).doubleValue();
                    }
                }
                collector.collect(new Tuple2(Integer.valueOf(this.instId), dArr));
            }
        });
        if (z) {
            mapPartition2 = mapPartition2.withBroadcastSet(groupReduceOperator, "w2vGroupTypes");
        }
        // Main pass. Emits two kinds of rows, told apart by the sign of field 3:
        //   field 3 >= 0 : a vocabulary row, field 3 being its assigned index;
        //   field 3 <  0 : a bound marker; field 1 is an index and -(field 3) - 1 is the bound slot.
        MapPartitionOperator withBroadcastSet = partitionCustom.mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Row>, Row>() {
            private static final long serialVersionUID = 1426278870328664711L;
            int size;
            long[] startIdx;
            long[] totalCountIdx;
            double[] weightStart;
            double[] weightTotal;
            double[] curWeightTotal;
            boolean isFirstPartition;
            final Map<Object, Integer> typeMap = new HashMap();

            public void open(Configuration configuration) throws Exception {
                this.size = 1;
                if (z) {
                    for (Tuple2 tuple2 : getRuntimeContext().getBroadcastVariable("w2vGroupTypes")) {
                        this.typeMap.put(tuple2.f0, tuple2.f1);
                    }
                    this.size = this.typeMap.size();
                }
                List<Tuple2> broadcastVariable = getRuntimeContext().getBroadcastVariable("w2vInstVocabSize");
                WordCountUtil.LOG.info("w2vInstVocabSize: {}", JsonConverter.gson.toJson(broadcastVariable));
                int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
                this.isFirstPartition = 0 == indexOfThisSubtask;
                this.startIdx = new long[this.size];
                this.totalCountIdx = new long[this.size];
                // jArr[i] ends up as the cumulative row count of groups 0..i over all partitions.
                long[] jArr = new long[this.size];
                for (Tuple2 tuple22 : broadcastVariable) {
                    for (int i = 0; i < this.size; i++) {
                        // Rows of the same group held by lower-numbered partitions come first.
                        if (((Integer) tuple22.f0).intValue() < indexOfThisSubtask) {
                            long[] jArr2 = this.startIdx;
                            int i2 = i;
                            jArr2[i2] = jArr2[i2] + ((long[]) tuple22.f1)[i];
                        }
                        long[] jArr3 = this.totalCountIdx;
                        int i3 = i;
                        jArr3[i3] = jArr3[i3] + ((long[]) tuple22.f1)[i];
                        if (i == 0) {
                            jArr[i] = this.totalCountIdx[i];
                        } else {
                            jArr[i] = jArr[i - 1] + this.totalCountIdx[i];
                        }
                    }
                }
                // Shift every group after the first behind all rows of the preceding groups.
                for (int i4 = 1; i4 < this.size; i4++) {
                    long[] jArr4 = this.startIdx;
                    int i5 = i4;
                    jArr4[i5] = jArr4[i5] + jArr[i4 - 1];
                }
                List<Tuple2> broadcastVariable2 = getRuntimeContext().getBroadcastVariable("w2vWeightInstSum");
                this.weightStart = new double[this.size];
                this.weightTotal = new double[this.size];
                this.curWeightTotal = new double[this.size];
                for (Tuple2 tuple23 : broadcastVariable2) {
                    for (int i6 = 0; i6 < this.size; i6++) {
                        if (((Integer) tuple23.f0).intValue() < indexOfThisSubtask) {
                            double[] dArr = this.weightStart;
                            int i7 = i6;
                            dArr[i7] = dArr[i7] + ((double[]) tuple23.f1)[i6];
                        }
                        double[] dArr2 = this.weightTotal;
                        int i8 = i6;
                        dArr2[i8] = dArr2[i8] + ((double[]) tuple23.f1)[i6];
                        if (((Integer) tuple23.f0).intValue() == indexOfThisSubtask) {
                            this.curWeightTotal[i6] = ((double[]) tuple23.f1)[i6];
                        }
                    }
                }
            }

            public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Row> collector) throws Exception {
                ArrayList arrayList = new ArrayList();
                Iterator<Tuple2<Integer, Row>> it = iterable.iterator();
                while (it.hasNext()) {
                    arrayList.add(it.next().f1);
                }
                arrayList.sort(new SortUtils.RowComparator(2));
                // jArr[g]: next index offset within group g, starting at the caller's start index.
                long[] jArr = new long[this.size];
                for (int i = 0; i < this.size; i++) {
                    jArr[i] = j;
                }
                // dArr[g]: running sum of field 2 for group g inside this partition.
                double[] dArr = new double[this.size];
                if (this.isFirstPartition) {
                    // Markers for the first slot (10001 * g) and the last slot (10001 * g + 10000)
                    // of every group.
                    for (int i2 = 0; i2 < this.size; i2++) {
                        collector.collect(Row.of(new Object[]{null, Long.valueOf(this.startIdx[i2] + jArr[i2]), null, Long.valueOf(-((10001 * i2) + 1))}));
                        collector.collect(Row.of(new Object[]{null, Long.valueOf((this.totalCountIdx[i2] - 1) + jArr[i2]), null, Long.valueOf(-(10001 * (i2 + 1)))}));
                    }
                }
                // jArr2[g]: highest bound slot (out of 10000) already reached by the weight that
                // precedes this partition.
                long[] jArr2 = new long[this.size];
                for (int i3 = 0; i3 < this.size; i3++) {
                    jArr2[i3] = (long) Math.floor((this.weightStart[i3] / this.weightTotal[i3]) * 10000.0d);
                    if ((this.weightStart[i3] / this.weightTotal[i3]) * 10000.0d <= jArr2[i3] && this.curWeightTotal[i3] > Criteria.INVALID_GAIN) {
                        collector.collect(Row.of(new Object[]{null, Long.valueOf(this.startIdx[i3] + jArr[i3]), null, Long.valueOf(-(jArr2[i3] + (10001 * i3) + 1))}));
                    }
                }
                int i4 = 0;
                Iterator it2 = arrayList.iterator();
                while (it2.hasNext()) {
                    Row row = (Row) it2.next();
                    if (z) {
                        i4 = this.typeMap.get(row.getField(3)).intValue();
                    }
                    int i5 = i4;
                    dArr[i5] = dArr[i5] + ((Double) row.getField(2)).doubleValue();
                    double d = dArr[i4] + this.weightStart[i4];
                    // Emit a marker for every bound slot the cumulative weight has now crossed.
                    while ((d / this.weightTotal[i4]) * 10000.0d >= jArr2[i4] + 1) {
                        int i6 = i4;
                        jArr2[i6] = jArr2[i6] + 1;
                        collector.collect(Row.of(new Object[]{null, Long.valueOf(this.startIdx[i4] + jArr[i4]), null, Long.valueOf(-(jArr2[i4] + (10001 * i4) + 1))}));
                    }
                    collector.collect(Row.of(new Object[]{row.getField(0), row.getField(1), row.getField(2), Long.valueOf(this.startIdx[i4] + jArr[i4])}));
                    int i7 = i4;
                    jArr[i7] = jArr[i7] + 1;
                }
            }
        }).withBroadcastSet(mapPartition, "w2vInstVocabSize").withBroadcastSet(mapPartition2, "w2vWeightInstSum");
        if (z) {
            withBroadcastSet = (MapPartitionOperator) withBroadcastSet.withBroadcastSet(groupReduceOperator, "w2vGroupTypes");
        }
        // Vocabulary rows: keep (field 0, field 1, assigned index).
        MapOperator map = withBroadcastSet.filter(new FilterFunction<Row>() {
            private static final long serialVersionUID = 4919758390927280039L;

            public boolean filter(Row row) throws Exception {
                // Bound markers are strictly negative, while index 0 is a legitimate assignment
                // (first row of the first partition when j == 0), so zero must pass the filter.
                return ((Long) row.getField(3)).longValue() >= 0;
            }
        }).map(new MapFunction<Row, Row>() {
            private static final long serialVersionUID = 7822668729101902036L;

            public Row map(Row row) throws Exception {
                return Row.of(new Object[]{row.getField(0), row.getField(1), row.getField(3)});
            }
        });
        // Bound markers: decode (slot, index) pairs and assemble them into one dense array.
        GroupReduceOperator reduceGroup = withBroadcastSet.filter(new FilterFunction<Row>() {
            private static final long serialVersionUID = -521522185349475218L;

            public boolean filter(Row row) throws Exception {
                return ((Long) row.getField(3)).longValue() < 0;
            }
        }).map(new MapFunction<Row, Tuple2<Long, Long>>() {
            private static final long serialVersionUID = 5866505010266569566L;

            public Tuple2<Long, Long> map(Row row) throws Exception {
                return new Tuple2<>(Long.valueOf((-((Long) row.getField(3)).longValue()) - 1), (Long) row.getField(1));
            }
        }).reduceGroup(new RichGroupReduceFunction<Tuple2<Long, Long>, Long[]>() {
            private static final long serialVersionUID = -3606835686371262547L;
            int size;

            public void open(Configuration configuration) throws Exception {
                this.size = 1;
                if (z) {
                    this.size = getRuntimeContext().getBroadcastVariable("w2vGroupTypes").size();
                }
            }

            public void reduce(Iterable<Tuple2<Long, Long>> iterable, Collector<Long[]> collector) throws Exception {
                HashMap hashMap = new HashMap();
                for (Tuple2<Long, Long> tuple2 : iterable) {
                    hashMap.put(tuple2.f0, tuple2.f1);
                }
                Long[] lArr = new Long[10001 * this.size];
                for (int i = 0; i < 10001 * this.size; i++) {
                    if (hashMap.containsKey(Long.valueOf(i))) {
                        lArr[i] = (Long) hashMap.get(Long.valueOf(i));
                    } else {
                        // Slots without a marker inherit the previous slot's index. Slot 0 is
                        // expected to be present (emitted by the first partition).
                        lArr[i] = lArr[i - 1];
                    }
                }
                collector.collect(lArr);
            }
        });
        if (z) {
            reduceGroup = (GroupReduceOperator) reduceGroup.withBroadcastSet(groupReduceOperator, "w2vGroupTypes");
        }
        // With grouping enabled: one long[] holding the first assigned index of every group.
        SingleInputUdfOperator singleInputUdfOperator = null;
        if (z) {
            singleInputUdfOperator = groupReduceOperator.reduceGroup(new RichGroupReduceFunction<Tuple2<Object, Integer>, long[]>() {
                private static final long serialVersionUID = -3498925783458257192L;
                long[] startIdx;
                int size;

                public void open(Configuration configuration) throws Exception {
                    this.size = getRuntimeContext().getBroadcastVariable("w2vGroupTypes").size();
                    List<Tuple2> broadcastVariable = getRuntimeContext().getBroadcastVariable("w2vInstVocabSize");
                    int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
                    this.startIdx = new long[this.size];
                    long[] jArr = new long[this.size];
                    long[] jArr2 = new long[this.size];
                    for (Tuple2 tuple2 : broadcastVariable) {
                        for (int i = 0; i < this.size; i++) {
                            if (((Integer) tuple2.f0).intValue() < indexOfThisSubtask) {
                                long[] jArr3 = this.startIdx;
                                int i2 = i;
                                jArr3[i2] = jArr3[i2] + ((long[]) tuple2.f1)[i];
                            }
                            int i3 = i;
                            jArr[i3] = jArr[i3] + ((long[]) tuple2.f1)[i];
                            if (i == 0) {
                                jArr2[i] = jArr[i];
                            } else {
                                jArr2[i] = jArr2[i - 1] + jArr[i];
                            }
                        }
                    }
                    for (int i4 = 1; i4 < this.size; i4++) {
                        long[] jArr4 = this.startIdx;
                        int i5 = i4;
                        jArr4[i5] = jArr4[i5] + jArr2[i4 - 1];
                    }
                }

                public void reduce(Iterable<Tuple2<Object, Integer>> iterable, Collector<long[]> collector) throws Exception {
                    Object[] objArr = new Object[this.size];
                    long[] jArr = new long[this.size];
                    for (Tuple2<Object, Integer> tuple2 : iterable) {
                        objArr[((Integer) tuple2.f1).intValue()] = tuple2.f0;
                        jArr[((Integer) tuple2.f1).intValue()] = this.startIdx[((Integer) tuple2.f1).intValue()] + j;
                    }
                    collector.collect(jArr);
                }
            }).withBroadcastSet(mapPartition, "w2vInstVocabSize").withBroadcastSet(groupReduceOperator, "w2vGroupTypes");
        }
        return new Tuple3<>(map, reduceGroup, singleInputUdfOperator);
    }

    /**
     * Replaces words by their vocabulary index, reading the vocabulary from its {@code word} and
     * {@code idx} columns.
     *
     * @param batchOperator  input operator
     * @param strArr         names of the word columns to translate
     * @param strArr2        names of the columns to keep, may be {@code null}
     * @param batchOperator2 indexed vocabulary
     */
    public static BatchOperator<?> transWord2Index(BatchOperator<?> batchOperator, String[] strArr, String[] strArr2, BatchOperator<?> batchOperator2) {
        final String vocabWordCol = WORD_COL_NAME;
        final String vocabIndexCol = INDEX_COL_NAME;
        return transWord2Index(batchOperator, strArr, strArr2, batchOperator2, vocabWordCol, vocabIndexCol);
    }

    /**
     * Replaces words by their vocabulary index.
     *
     * @param batchOperator  input operator
     * @param strArr         names of the word columns to translate
     * @param strArr2        names of the columns to keep, may be {@code null}
     * @param batchOperator2 indexed vocabulary
     * @param str            name of the vocabulary's word column
     * @param str2           name of the vocabulary's index column
     */
    public static BatchOperator<?> transWord2Index(BatchOperator<?> batchOperator, String[] strArr, String[] strArr2, BatchOperator<?> batchOperator2, String str, String str2) {
        // Word mode: each selected column holds one word, so no delimiter is needed.
        final boolean singleWordColumns = true;
        final String unusedDelimiter = null;
        return trans(batchOperator, strArr, strArr2, batchOperator2, str, str2, singleWordColumns, unusedDelimiter);
    }

    /**
     * Replaces documents by the comma-joined indices of their known words, reading the vocabulary
     * from its {@code word} and {@code idx} columns.
     *
     * @param batchOperator  input operator
     * @param strArr         names of the document columns to translate
     * @param str            delimiter used to split documents into words
     * @param strArr2        names of the columns to keep, may be {@code null}
     * @param batchOperator2 indexed vocabulary
     */
    public static BatchOperator<?> transDoc2IndexVector(BatchOperator<?> batchOperator, String[] strArr, String str, String[] strArr2, BatchOperator<?> batchOperator2) {
        final String vocabWordCol = WORD_COL_NAME;
        final String vocabIndexCol = INDEX_COL_NAME;
        return transDoc2IndexVector(batchOperator, strArr, str, strArr2, batchOperator2, vocabWordCol, vocabIndexCol);
    }

    /**
     * Replaces documents by the comma-joined indices of their known words.
     *
     * @param batchOperator  input operator
     * @param strArr         names of the document columns to translate
     * @param str            delimiter used to split documents into words
     * @param strArr2        names of the columns to keep, may be {@code null}
     * @param batchOperator2 indexed vocabulary
     * @param str2           name of the vocabulary's word column
     * @param str3           name of the vocabulary's index column
     */
    public static BatchOperator<?> transDoc2IndexVector(BatchOperator<?> batchOperator, String[] strArr, String str, String[] strArr2, BatchOperator<?> batchOperator2, String str2, String str3) {
        // Document mode: each selected column is split with the delimiter before lookup.
        final boolean singleWordColumns = false;
        return trans(batchOperator, strArr, strArr2, batchOperator2, str2, str3, singleWordColumns, str);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Shared implementation of the word-to-index and document-to-index translations.
     *
     * <p>The vocabulary is reduced to its word and index columns and broadcast to a
     * {@link GenContentMapping}. The output keeps the names of the translated columns, followed
     * by the kept columns with their original names and types.
     *
     * @param batchOperator  input operator
     * @param strArr         names of the columns to translate (must be string columns)
     * @param strArr2        names of the columns to keep, may be {@code null} or empty
     * @param batchOperator2 indexed vocabulary
     * @param str            name of the vocabulary's word column
     * @param str2           name of the vocabulary's index column
     * @param z              {@code true} for word mode, {@code false} for document mode
     * @param str3           document delimiter, only read when {@code z} is {@code false}
     * @return a new table-source operator in the same ML environment
     * @throws AkIllegalArgumentException if a requested column is missing or has the wrong type
     */
    private static BatchOperator<?> trans(BatchOperator<?> batchOperator, String[] strArr, String[] strArr2, BatchOperator<?> batchOperator2, String str, String str2, boolean z, String str3) {
        String[] colNames = batchOperator.getColNames();
        TypeInformation<?>[] colTypes = batchOperator.getColTypes();
        int[] selectedIdx = findColIdx(strArr, colNames, colTypes);
        int[] keptIdx = findAppendColIdx(strArr2, colNames);
        SingleInputUdfOperator withBroadcastSet = batchOperator.getDataSet().map(new GenContentMapping(selectedIdx, keptIdx, z, str3)).withBroadcastSet(batchOperator2.select(str + "," + str2).getDataSet(), "vocabulary");
        int transColSize = selectedIdx.length;
        // findAppendColIdx returns null when no columns are kept.
        int keepColSize = keptIdx == null ? 0 : keptIdx.length;
        int outputColSize = transColSize + keepColSize;
        String[] outNames = new String[outputColSize];
        TypeInformation[] outTypes = new TypeInformation[outputColSize];
        for (int i = 0; i < transColSize; i++) {
            outNames[i] = colNames[selectedIdx[i]];
            // GenContentMapping writes the vocabulary's Long index in word mode and a
            // comma-joined String in document mode; the declared type has to match.
            outTypes[i] = z ? Types.LONG : Types.STRING;
        }
        for (int i = 0; i < keepColSize; i++) {
            outNames[transColSize + i] = colNames[keptIdx[i]];
            outTypes[transColSize + i] = colTypes[keptIdx[i]];
        }
        return (BatchOperator) new TableSourceBatchOp(DataSetConversionUtil.toTable(batchOperator.getMLEnvironmentId(), (DataSet<Row>) withBroadcastSet, outNames, (TypeInformation<?>[]) outTypes)).setMLEnvironmentId(batchOperator.getMLEnvironmentId());
    }

    /**
     * Resolves the selected column names to their positions and checks that each is a string
     * column.
     *
     * @param strArr             names of the selected columns
     * @param strArr2            all column names of the input
     * @param typeInformationArr column types, parallel to {@code strArr2}
     * @return position in {@code strArr2} of every entry of {@code strArr}, in the same order
     * @throws AkIllegalArgumentException if the number of selected columns is out of range, a
     *     column does not exist, or a column is not of string type
     */
    private static int[] findColIdx(String[] strArr, String[] strArr2, TypeInformation<?>[] typeInformationArr) {
        int length = strArr.length;
        if (length < 1 || length > strArr2.length) {
            throw new AkIllegalArgumentException(String.format("selected column size out of range [%d, %d]", 1, Integer.valueOf(strArr2.length)));
        }
        int[] positions = new int[length];
        for (int i = 0; i < length; i++) {
            positions[i] = -1;
            for (int c = 0; c < strArr2.length; c++) {
                if (!strArr[i].equals(strArr2[c])) {
                    continue;
                }
                if (!Types.STRING.equals(typeInformationArr[c])) {
                    throw new AkIllegalArgumentException(String.format("type of column: %s must be string.", strArr2[c]));
                }
                positions[i] = c;
                // Stop at the first match; without this the search never terminates.
                break;
            }
            if (positions[i] == -1) {
                throw new AkIllegalArgumentException(String.format("column %s does not exist.", strArr[i]));
            }
        }
        return positions;
    }

    /**
     * Resolves the names of the columns to keep to their positions.
     *
     * @param strArr  names of the columns to keep, may be {@code null} or empty
     * @param strArr2 all column names of the input
     * @return position in {@code strArr2} of every entry of {@code strArr}, or {@code null} when
     *     there is nothing to keep
     * @throws AkIllegalArgumentException if more columns are requested than exist, or a column
     *     does not exist
     */
    private static int[] findAppendColIdx(String[] strArr, String[] strArr2) {
        if (strArr == null || strArr.length == 0) {
            return null;
        }
        final int wanted = strArr.length;
        if (wanted > strArr2.length) {
            throw new AkIllegalArgumentException(String.format("selected append column size out of range [%d, %d]", 0, Integer.valueOf(strArr2.length)));
        }
        int[] positions = new int[wanted];
        for (int k = 0; k < wanted; k++) {
            int found = -1;
            // Linear scan that stops at the first column with a matching name.
            for (int c = 0; c < strArr2.length && found < 0; c++) {
                if (strArr[k].equals(strArr2[c])) {
                    found = c;
                }
            }
            if (found < 0) {
                throw new AkIllegalArgumentException(String.format("column %s does not exist.", strArr[k]));
            }
            positions[k] = found;
        }
        return positions;
    }

    /**
     * Routes every row to the partition given by its Integer tag (modulo the partition count),
     * sorts each partition locally on field {@code i}, and numbers the rows. A partition's
     * numbering starts at the total row count of all lower-numbered partitions, taken from the
     * {@code "partitionCnt"} broadcast set.
     *
     * @param dataSet  rows tagged with their target partition
     * @param dataSet2 (partition number, row count) pairs
     * @param i        index of the field to sort on
     * @return rows paired with their assigned number
     */
    public static DataSet<Tuple2<Long, Row>> localSort(DataSet<Tuple2<Integer, Row>> dataSet, DataSet<Tuple2<Integer, Long>> dataSet2, final int i) {
        Partitioner<Integer> byTag = new Partitioner<Integer>() {
            private static final long serialVersionUID = -6835474886524807584L;

            @Override
            public int partition(Integer num, int i2) {
                return num.intValue() % i2;
            }
        };
        RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Long, Row>> sortAndNumber = new RichMapPartitionFunction<Tuple2<Integer, Row>, Tuple2<Long, Row>>() {
            private static final long serialVersionUID = 9152306255552921122L;
            transient long startIdx;

            @Override
            public void open(Configuration configuration) throws Exception {
                List<Tuple2<Integer, Long>> partitionCounts = getRuntimeContext().getBroadcastVariable("partitionCnt");
                int self = getRuntimeContext().getIndexOfThisSubtask();
                long precedingRows = 0L;
                for (Tuple2<Integer, Long> entry : partitionCounts) {
                    if (entry.f0.intValue() < self) {
                        precedingRows += entry.f1.longValue();
                    }
                }
                this.startIdx = precedingRows;
            }

            @Override
            public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Tuple2<Long, Row>> collector) throws Exception {
                List<Row> rows = new ArrayList<>();
                for (Tuple2<Integer, Row> tagged : iterable) {
                    rows.add(tagged.f1);
                }
                rows.sort(new SortUtils.RowComparator(i));
                long next = this.startIdx;
                for (Row sorted : rows) {
                    collector.collect(Tuple2.of(Long.valueOf(next), sorted));
                    next++;
                }
            }
        };
        return dataSet.partitionCustom(byTag, 0).mapPartition(sortAndNumber).name("local_sort").withBroadcastSet(dataSet2, "partitionCnt");
    }
}
