package com.alibaba.alink.operator.batch.nlp;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.utils.ExpTableArray;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.WithTrainInfo;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.nlp.Word2VecModelDataConverter;
import com.alibaba.alink.operator.common.nlp.Word2VecTrainInfo;
import com.alibaba.alink.operator.common.nlp.WordCountUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.nlp.Word2VecTrainParams;
import com.alibaba.alink.params.shared.tree.HasSeed;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = {TypeCollections.STRING_TYPES})
@NameCn("Word2Vec训练")
@NameEn("Word2Vec Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.nlp.Word2Vec")
/* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp.class */
public class Word2VecTrainBatchOp extends BatchOperator<Word2VecTrainBatchOp> implements Word2VecTrainParams<Word2VecTrainBatchOp>, WithTrainInfo<Word2VecTrainInfo, Word2VecTrainBatchOp> {
    /** Serialization version identifier of this operator class. */
    private static final long serialVersionUID = -1901810620054339260L;
    /** Class logger; the nested CalcModel writes its start/end messages through it. */
    private static final Logger LOG = LoggerFactory.getLogger(Word2VecTrainBatchOp.class);
    /**
     * Length used by createBinaryTree when it allocates the two scratch arrays that record a
     * word's path through the tree. Public, static and non-final, so it can be reassigned at runtime.
     */
    public static int MAX_CODE_LENGTH = 40;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$AvgInputOutput.class */
    /**
     * Queue step that divides the shared "input" and "output" arrays by the number of tasks and
     * records the task-averaged loss of the current step.
     *
     * <p>In linkFrom this step is queued directly after the AllReduce steps on "input", "output"
     * and "lossInfo" (presumably element-wise sums across tasks, which the division turns into
     * means; confirm against AllReduce).
     */
    public static class AvgInputOutput extends ComputeFunction {
        private static final long serialVersionUID = -6272951344479535648L;

        private AvgInputOutput() {
        }

        /**
         * Divides both model arrays in place and appends {@code lossInfo[0] / numTask} to the
         * "lossIterInfo" list, creating and registering that list on first use.
         *
         * @param comContext context holding "input", "output", "lossInfo" and "lossIterInfo"
         */
        @Override // com.alibaba.alink.common.comqueue.ComputeFunction
        public void calc(ComContext comContext) {
            divideByTaskCount((double[]) comContext.getObj("input"), comContext);
            divideByTaskCount((double[]) comContext.getObj("output"), comContext);

            List lossHistory = (List) comContext.getObj("lossIterInfo");
            if (lossHistory == null) {
                lossHistory = new ArrayList();
                comContext.putObj("lossIterInfo", lossHistory);
            }
            double[] lossInfo = (double[]) comContext.getObj("lossInfo");
            lossHistory.add(Double.valueOf(lossInfo[0] / comContext.getNumTask()));
        }

        /** Divides every element of {@code values} in place by the context's task count. */
        private static void divideByTaskCount(double[] values, ComContext comContext) {
            for (int idx = 0; idx < values.length; idx++) {
                values[idx] = values[idx] / comContext.getNumTask();
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$CalcModel.class */
    /**
     * Performs the in-place gradient updates of the "input" and "output" weight arrays for a
     * batch of encoded sentences, following each word's code/point path through the binary tree
     * built by createBinaryTree.
     */
    private static class CalcModel {
        private final int vectorSize;
        private final int window;
        private final double alpha;
        private boolean randomWindow;
        private double[] input;
        private double[] output;
        private int taskId;
        private Word[] vocab;
        private Random random;

        /**
         * @param vectorSize   number of components per word vector
         * @param seed         seed of the random generator used for window shrinking
         * @param randomWindow whether to shrink the window by a random amount per position
         * @param window       maximum distance between a word and its context word
         * @param alpha        factor applied to every gradient
         * @param taskId       id of the running task, used in log messages only
         * @param vocab        vocabulary indexed by word id
         * @param input        flat array of input vectors, updated in place
         * @param output       flat array of tree-node vectors, updated in place
         */
        public CalcModel(int vectorSize, long seed, boolean randomWindow, int window, double alpha, int taskId,
                         Word[] vocab, double[] input, double[] output) {
            this.vectorSize = vectorSize;
            this.randomWindow = randomWindow;
            this.window = window;
            this.alpha = alpha;
            this.vocab = vocab;
            this.input = input;
            this.output = output;
            this.taskId = taskId;
            this.random = new Random(seed);
        }

        /**
         * Runs one pass over the given sentences and updates the weight arrays in place.
         *
         * @param list sentences, each an array of word ids
         * @return accumulated loss divided by the number of (word, context) pairs visited, or
         *         {@code Criteria.INVALID_GAIN} when no pair was visited
         */
        public double update(List<int[]> list) {
            Word2VecTrainBatchOp.LOG.info("taskId: {}, map partition start", Integer.valueOf(this.taskId));
            double[] inputDelta = new double[this.vectorSize];
            double lossSum = 0.0d;
            double pairCount = 0.0d;
            for (int[] sentence : list) {
                for (int pos = 0; pos < sentence.length; pos++) {
                    // One random draw per position; it trims the window symmetrically.
                    int shrink = 0;
                    if (this.randomWindow) {
                        shrink = this.random.nextInt(this.window);
                    }
                    int upper = ((this.window * 2) + 1) - shrink;
                    for (int offset = shrink; offset < upper; offset++) {
                        if (offset == this.window) {
                            // offset == window is the word itself.
                            continue;
                        }
                        int contextPos = (pos - this.window) + offset;
                        if (contextPos < 0 || contextPos >= sentence.length) {
                            continue;
                        }
                        int inputBase = sentence[contextPos] * this.vectorSize;
                        Arrays.fill(inputDelta, Criteria.INVALID_GAIN);
                        Word word = this.vocab[sentence[pos]];
                        int codeLength = word.code.length;
                        pairCount += 1.0d;
                        for (int level = 0; level < codeLength; level++) {
                            int outputBase = word.point[level] * this.vectorSize;
                            double dot = 0.0d;
                            for (int c = 0; c < this.vectorSize; c++) {
                                dot += this.input[inputBase + c] * this.output[outputBase + c];
                            }
                            // Written as a negated range test so that NaN is skipped as well.
                            if (!(dot > -6.0d && dot < 6.0d)) {
                                continue;
                            }
                            double sigmoid = ExpTableArray.sigmoidTable[(int) ((dot + 6.0d) * 84.0d)];
                            double gradient = ((1.0f - word.code[level]) - sigmoid) * this.alpha;
                            double likelihood = word.code[level] == 0 ? sigmoid : 1.0d - sigmoid;
                            lossSum = lossSum + (-ExpTableArray.log(likelihood));
                            // Accumulate the input gradient from the not-yet-updated output row ...
                            for (int c = 0; c < this.vectorSize; c++) {
                                inputDelta[c] = inputDelta[c] + (gradient * this.output[outputBase + c]);
                            }
                            // ... then move the output row.
                            for (int c = 0; c < this.vectorSize; c++) {
                                this.output[outputBase + c] = this.output[outputBase + c] + (gradient * this.input[inputBase + c]);
                            }
                        }
                        for (int c = 0; c < this.vectorSize; c++) {
                            this.input[inputBase + c] = this.input[inputBase + c] + inputDelta[c];
                        }
                    }
                }
            }
            Word2VecTrainBatchOp.LOG.info("taskId: {}, map partition end", Integer.valueOf(this.taskId));
            if (pairCount == Criteria.INVALID_GAIN) {
                return Criteria.INVALID_GAIN;
            }
            return lossSum / pairCount;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$CreateVocab.class */
    /**
     * Group-reduce that turns the indexed vocabulary rows (word, count, index) into
     * (index, word, Word) tuples whose Word entries carry their binary-tree code and path.
     */
    public static class CreateVocab extends RichGroupReduceFunction<Row, Tuple3<Integer, String, Word>> {
        private static final long serialVersionUID = 5918268703417386926L;
        int vocSize;

        private CreateVocab() {
        }

        /** Reads the vocabulary size from the first element of the "vocSize" broadcast set. */
        public void open(Configuration configuration) throws Exception {
            BroadcastVariableInitializer<Long, Integer> firstElementAsInt = new BroadcastVariableInitializer<Long, Integer>() {
                public Integer initializeBroadcastVariable(Iterable<Long> iterable) {
                    return Integer.valueOf(iterable.iterator().next().intValue());
                }

                /* renamed from: initializeBroadcastVariable, reason: collision with other method in class */
                // Decompiler-renamed bridge method; nothing in this file calls it.
                public /* bridge */ /* synthetic */ Object m280initializeBroadcastVariable(Iterable iterable) {
                    return initializeBroadcastVariable((Iterable<Long>) iterable);
                }
            };
            Integer size = (Integer) getRuntimeContext().getBroadcastVariableWithInitializer("vocSize", firstElementAsInt);
            this.vocSize = size.intValue();
        }

        /**
         * Places every row at the slot named by its index field, builds the binary tree over the
         * resulting Word array and emits one tuple per slot.
         *
         * @param iterable  rows with the word at field 0, its count at field 1 and its index at field 2
         * @param collector receives (index, word, Word) for every index in {@code [0, vocSize)}
         */
        public void reduce(Iterable<Row> iterable, Collector<Tuple3<Integer, String, Word>> collector) throws Exception {
            String[] texts = new String[this.vocSize];
            Word[] entries = new Word[this.vocSize];
            for (Row row : iterable) {
                Word entry = new Word();
                entry.cnt = ((Long) row.getField(1)).longValue();
                int slot = ((Integer) row.getField(2)).intValue();
                entries[slot] = entry;
                texts[slot] = (String) row.getField(0);
            }
            Word2VecTrainBatchOp.createBinaryTree(entries);
            int index = 0;
            for (Word entry : entries) {
                collector.collect(new Tuple3<>(Integer.valueOf(index), texts[index], entry));
                index++;
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$Criterion.class */
    /**
     * Stop criterion of the iteration: true once the number of completed steps equals the
     * broadcast "syncNum" value multiplied by the configured number of iterations.
     */
    public static class Criterion extends CompareCriterionFunction {
        private static final long serialVersionUID = -5209402952030754112L;
        Params params;

        public Criterion(Params params) {
            this.params = params;
        }

        /**
         * @param comContext context providing the step number and the "syncNum" broadcast list
         * @return whether {@code stepNo - 1 == syncNum * NUM_ITER}
         */
        @Override // com.alibaba.alink.common.comqueue.CompareCriterionFunction
        public boolean calc(ComContext comContext) {
            List syncNumHolder = (List) comContext.getObj("syncNum");
            int stepsPerIteration = ((Integer) syncNumHolder.get(0)).intValue();
            int iterations = ((Integer) this.params.get(Word2VecTrainParams.NUM_ITER)).intValue();
            return comContext.getStepNo() - 1 == stepsPerIteration * iterations;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$InitialVocabAndBuffer.class */
    /**
     * First-step initializer: converts the broadcast initial vectors and vocabulary into the flat
     * "input" array, a zero-filled "output" array and the "vocab" array used by later steps.
     */
    public static class InitialVocabAndBuffer extends ComputeFunction {
        private static final long serialVersionUID = -5099694286000869372L;
        Params params;

        public InitialVocabAndBuffer(Params params) {
            this.params = params;
        }

        /**
         * Does nothing unless the step number is 1. On step 1 it stores "input"
         * ({@code vectorSize * vocSize} values), "output" ({@code vectorSize * (vocSize - 1)} values)
         * and "vocab", then removes the "initialModel" and "vocabWithoutWordStr" objects.
         *
         * @param comContext context holding the broadcast "vocSize", "initialModel" and
         *                   "vocabWithoutWordStr" lists
         */
        @Override // com.alibaba.alink.common.comqueue.ComputeFunction
        public void calc(ComContext comContext) {
            if (comContext.getStepNo() != 1) {
                return;
            }
            int vectorSize = ((Integer) this.params.get(Word2VecTrainParams.VECTOR_SIZE)).intValue();
            List vocSizeHolder = (List) comContext.getObj("vocSize");
            List initialVectors = (List) comContext.getObj("initialModel");
            List vocabEntries = (List) comContext.getObj("vocabWithoutWordStr");
            int vocSize = ((Long) vocSizeHolder.get(0)).intValue();

            double[] input = new double[vectorSize * vocSize];
            Word[] vocab = new Word[vocSize];
            for (int k = 0; k < vocSize; k++) {
                // Both lists arrive in arbitrary order; the tuple's first field names the slot.
                Tuple2 vectorEntry = (Tuple2) initialVectors.get(k);
                int vectorOffset = ((Integer) vectorEntry.f0).intValue() * vectorSize;
                System.arraycopy(vectorEntry.f1, 0, input, vectorOffset, vectorSize);
                Tuple2 vocabEntry = (Tuple2) vocabEntries.get(k);
                vocab[((Integer) vocabEntry.f0).intValue()] = (Word) vocabEntry.f1;
            }

            comContext.putObj("input", input);
            comContext.putObj("output", new double[vectorSize * (vocSize - 1)]);
            comContext.putObj("vocab", vocab);
            comContext.removeObj("initialModel");
            comContext.removeObj("vocabWithoutWordStr");
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$SerializeModel.class */
    /**
     * Final step of the iteration: task 0 emits one train-info row followed by one row per word
     * vector; every other task emits nothing.
     *
     * <p>Row layout is (kind, index, vector, json): kind 0 carries the train-info JSON in field 3,
     * kind 1 carries the word index in field 1 and its vector in field 2.
     */
    public static class SerializeModel extends CompleteResultFunction {
        private static final long serialVersionUID = -6244849890744256651L;
        Params params;

        public SerializeModel(Params params) {
            this.params = params;
        }

        /**
         * @param comContext context holding "lossIterInfo", "vocSize" and the trained "input" array
         * @return the rows described on the class for task 0, {@code null} for all other tasks
         */
        @Override // com.alibaba.alink.common.comqueue.CompleteResultFunction
        public List<Row> calc(ComContext comContext) {
            if (comContext.getTaskId() != 0) {
                return null;
            }
            List<?> lossHistory = (List<?>) comContext.getObj("lossIterInfo");
            int vocSize = ((Long) ((List<?>) comContext.getObj("vocSize")).get(0)).intValue();
            int vectorSize = ((Integer) this.params.get(Word2VecTrainParams.VECTOR_SIZE)).intValue();

            // The values are passed to Params.set directly. The previous casts of the loss array and
            // the vocabulary size to ParamInfo were invalid (an array can never be a ParamInfo).
            Double[] losses = lossHistory.toArray(new Double[0]);
            String trainInfoJson = this.params
                .set(Word2VecTrainInfo.LOSS, losses)
                .set(Word2VecTrainInfo.NUM_VOCAB, Long.valueOf((long) vocSize))
                .toJson();

            List<Row> rows = new ArrayList<>(vocSize + 1);
            rows.add(Row.of(new Object[]{0, 0, null, trainInfoJson}));

            double[] input = (double[]) comContext.getObj("input");
            for (int wordIndex = 0; wordIndex < vocSize; wordIndex++) {
                DenseVector vector = new DenseVector(vectorSize);
                System.arraycopy(input, wordIndex * vectorSize, vector.getData(), 0, vectorSize);
                rows.add(Row.of(new Object[]{1, Integer.valueOf(wordIndex), vector, null}));
            }
            return rows;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$UpdateModel.class */
    /**
     * Training step: picks the slice of the local "trainData" that belongs to the current step and
     * lets a CalcModel update the shared "input"/"output" arrays with it.
     */
    public static class UpdateModel extends ComputeFunction {
        private static final long serialVersionUID = -200466448350631442L;
        Params params;

        public UpdateModel(Params params) {
            this.params = params;
        }

        /**
         * Ensures a one-element "lossInfo" array exists, then, if this task has training data,
         * stores the loss returned by CalcModel.update in {@code lossInfo[0]}.
         *
         * <p>The local data is divided into "syncNum" slices via DefaultDistributedInfo; step
         * {@code s} processes slice {@code (s - 1) % syncNum}.
         *
         * @param comContext context holding "trainData", "syncNum", "vocab", "input" and "output"
         */
        @Override // com.alibaba.alink.common.comqueue.ComputeFunction
        public void calc(ComContext comContext) {
            List trainData = (List) comContext.getObj("trainData");
            int syncNum = ((Integer) ((List) comContext.getObj("syncNum")).get(0)).intValue();
            if (comContext.getObj("lossInfo") == null) {
                comContext.putObj("lossInfo", new double[]{Criteria.INVALID_GAIN});
            }
            double[] lossInfo = (double[]) comContext.getObj("lossInfo");
            if (trainData == null) {
                return;
            }

            DefaultDistributedInfo slicer = new DefaultDistributedInfo();
            long startPos = slicer.startPos((comContext.getStepNo() - 1) % syncNum, syncNum, trainData.size());

            // Each task seeds its generator with the configured seed plus its task id.
            CalcModel model = new CalcModel(
                ((Integer) this.params.get(Word2VecTrainParams.VECTOR_SIZE)).intValue(),
                ((Long) this.params.get(HasSeed.SEED)).longValue() + comContext.getTaskId(),
                Boolean.parseBoolean((String) this.params.get(Word2VecTrainParams.RANDOM_WINDOW)),
                ((Integer) this.params.get(Word2VecTrainParams.WINDOW)).intValue(),
                ((Double) this.params.get(Word2VecTrainParams.ALPHA)).doubleValue(),
                comContext.getTaskId(),
                (Word[]) comContext.getObj("vocab"),
                (double[]) comContext.getObj("input"),
                (double[]) comContext.getObj("output"));

            long endPos = startPos + slicer.localRowCnt((comContext.getStepNo() - 1) % syncNum, syncNum, trainData.size());
            lossInfo[0] = model.update(trainData.subList((int) startPos, (int) endPos));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$UseVocabWithoutWordString.class */
    /** Drops the word text from a vocabulary tuple, keeping only (index, Word). */
    public static class UseVocabWithoutWordString implements MapFunction<Tuple3<Integer, String, Word>, Tuple2<Integer, Word>> {
        private static final long serialVersionUID = -9049426378553185090L;

        private UseVocabWithoutWordString() {
        }

        /**
         * @param tuple3 (index, word text, Word) vocabulary tuple
         * @return a new tuple holding the index and the Word entry
         */
        public Tuple2<Integer, Word> map(Tuple3<Integer, String, Word> tuple3) throws Exception {
            Integer index = tuple3.f0;
            Word entry = tuple3.f2;
            return new Tuple2<>(index, entry);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$Word.class */
    /**
     * Vocabulary entry: a word's count together with its code and path in the binary tree
     * produced by createBinaryTree.
     */
    public static class Word implements Serializable {
        private static final long serialVersionUID = 7064713372411549086L;
        /** Count of the word; CreateVocab fills it from field 1 of the vocabulary row. */
        public long cnt;
        /**
         * Tree-node offsets along the word's path, assigned by createBinaryTree. CalcModel
         * multiplies each entry by the vector size to address the "output" array.
         */
        public int[] point;
        /** Branch value recorded for each level of the path, assigned by createBinaryTree; same length as point. */
        public int[] code;

        private Word() {
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp$initialModel.class */
    /**
     * Creates the initial vector of every vocabulary entry by drawing each component from
     * {@link Random#nextFloat()}, seeded per subtask.
     */
    public static class initialModel extends RichMapPartitionFunction<Tuple2<Integer, Word>, Tuple2<Integer, double[]>> {
        private static final long serialVersionUID = -5113354983404028347L;
        private final long seed;
        private final int vectorSize;
        Random random = new Random();

        /**
         * @param j base seed; the subtask index is added to it in open
         * @param i number of components per vector
         */
        public initialModel(long j, int i) {
            this.seed = j;
            this.vectorSize = i;
        }

        /** Re-seeds the generator with the base seed plus the index of this subtask. */
        public void open(Configuration configuration) throws Exception {
            int subtask = getRuntimeContext().getIndexOfThisSubtask();
            this.random.setSeed(this.seed + subtask);
        }

        /**
         * Emits (index, vector) for every incoming (index, Word) tuple.
         *
         * @param iterable  vocabulary tuples of this partition
         * @param collector receives one freshly drawn vector per tuple
         */
        public void mapPartition(Iterable<Tuple2<Integer, Word>> iterable, Collector<Tuple2<Integer, double[]>> collector) throws Exception {
            Iterator<Tuple2<Integer, Word>> entries = iterable.iterator();
            while (entries.hasNext()) {
                Tuple2<Integer, Word> entry = entries.next();
                double[] vector = new double[this.vectorSize];
                int component = 0;
                while (component < this.vectorSize) {
                    vector[component] = this.random.nextFloat();
                    component++;
                }
                collector.collect(new Tuple2<>(entry.f0, vector));
            }
        }
    }

    /** Creates the operator without passing any parameters. */
    public Word2VecTrainBatchOp() {
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters handed to the BatchOperator constructor
     */
    public Word2VecTrainBatchOp(Params params) {
        super(params);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds a Huffman-style binary tree over the vocabulary and stores, for every word, the
     * branch bits ({@code code}) and inner-node offsets ({@code point}) of its path from the root.
     *
     * <p>The merge walks the leaves from the last index downwards, so it relies on
     * {@code wordArr} being ordered by non-increasing count (sortedIndexVocab assigns indices
     * that way). Inner nodes occupy indices {@code [length, 2 * length - 2]} of the work arrays;
     * the last one is the root.
     *
     * <p>An empty array is left untouched. For a single-word vocabulary there is no inner node,
     * so the word receives a one-element path whose {@code point[0]} is {@code -1}.
     *
     * @param wordArr vocabulary entries with {@code cnt} set; {@code code} and {@code point} are overwritten
     */
    public static void createBinaryTree(Word[] wordArr) {
        int vocabSize = wordArr.length;
        if (vocabSize == 0) {
            // Nothing to build; also avoids allocating arrays of negative length below.
            return;
        }
        int nodeCount = (vocabSize * 2) - 1;
        long[] count = new long[nodeCount];
        int[] branch = new int[nodeCount];
        int[] parent = new int[nodeCount];
        for (int leaf = 0; leaf < vocabSize; leaf++) {
            count[leaf] = wordArr[leaf].cnt;
        }
        // Inner nodes that are not built yet must compare greater than every real count. Counts
        // are longs, so the sentinel has to be Long.MAX_VALUE rather than Integer.MAX_VALUE.
        Arrays.fill(count, vocabSize, nodeCount, Long.MAX_VALUE);

        int leafPos = vocabSize - 1;
        int innerPos = vocabSize;
        for (int merged = 0; merged < vocabSize - 1; merged++) {
            int first;
            if (leafPos >= 0 && count[leafPos] < count[innerPos]) {
                first = leafPos--;
            } else {
                first = innerPos++;
            }
            int second;
            if (leafPos >= 0 && count[leafPos] < count[innerPos]) {
                second = leafPos--;
            } else {
                second = innerPos++;
            }
            int newNode = vocabSize + merged;
            count[newNode] = count[first] + count[second];
            parent[first] = newNode;
            parent[second] = newNode;
            branch[second] = 1;
        }

        int root = nodeCount - 1;
        // Scratch buffers for one leaf-to-root walk; they start at MAX_CODE_LENGTH and grow when a
        // path turns out to be longer instead of overflowing.
        int[] pathNodes = new int[MAX_CODE_LENGTH];
        int[] pathBits = new int[MAX_CODE_LENGTH];
        for (int wordIndex = 0; wordIndex < vocabSize; wordIndex++) {
            int node = wordIndex;
            int depth = 0;
            do {
                if (depth == pathBits.length) {
                    int grown = Math.max(1, depth * 2);
                    pathBits = Arrays.copyOf(pathBits, grown);
                    pathNodes = Arrays.copyOf(pathNodes, grown);
                }
                pathBits[depth] = branch[node];
                pathNodes[depth] = node;
                depth++;
                node = parent[node];
            } while (node != root);

            // The walk went leaf -> root; code and point are stored root -> leaf.
            int[] code = new int[depth];
            for (int k = 0; k < depth; k++) {
                code[(depth - k) - 1] = pathBits[k];
            }
            int[] point = new int[depth];
            point[0] = vocabSize - 2;
            for (int k = 1; k < depth; k++) {
                point[depth - k] = pathNodes[k] - vocabSize;
            }
            wordArr[wordIndex].code = code;
            wordArr[wordIndex].point = point;
        }
    }

    /**
     * Appends a vocabulary index to every (word, count) row such that a larger count receives a
     * smaller index.
     *
     * <p>The rows are first passed through {@code SortUtils.pSort} on field 1 (presumably a
     * range partitioning that tags each row with an ordered partition id; confirm against
     * SortUtils) and routed to the subtask named by that id. Each subtask then sorts its rows by
     * ascending count, and the row at global ascending position {@code p} is given index
     * {@code total - 1 - p}.
     *
     * @param dataSet rows whose field 1 is the word count as a {@code Long}
     * @return the input rows with the index appended as an additional {@code Integer} field
     */
    private static DataSet<Row> sortedIndexVocab(DataSet<Row> dataSet) {
        PartitionOperator partitionCustom = ((DataSet) SortUtils.pSort(dataSet, 1).f0).partitionCustom(new Partitioner<Integer>() {
            private static final long serialVersionUID = 7033675545004935349L;

            public int partition(Integer num, int i) {
                return num.intValue();
            }
        }, 0);
        return partitionCustom.mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Row>, Row>() {
            private static final long serialVersionUID = -8439325113876456518L;
            /** Number of rows held by subtasks with a smaller index. */
            int start;
            /** Number of rows held by this subtask. */
            int curLen;
            /** Number of rows across all subtasks. */
            int total;

            public void open(Configuration configuration) throws Exception {
                List<Tuple2> broadcastVariable = getRuntimeContext().getBroadcastVariable(WordCountUtil.COUNT_COL_NAME);
                int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
                this.start = 0;
                this.curLen = 0;
                this.total = 0;
                for (Tuple2 tuple2 : broadcastVariable) {
                    if (((Integer) tuple2.f0).intValue() < indexOfThisSubtask) {
                        this.start = (int) (this.start + ((Long) tuple2.f1).longValue());
                    }
                    if (((Integer) tuple2.f0).intValue() == indexOfThisSubtask) {
                        this.curLen = ((Long) tuple2.f1).intValue();
                    }
                    this.total += ((Long) tuple2.f1).intValue();
                }
            }

            public void mapPartition(Iterable<Tuple2<Integer, Row>> iterable, Collector<Row> collector) throws Exception {
                if (this.curLen <= 0) {
                    return;
                }
                Row[] rowArr = new Row[this.curLen];
                int i = 0;
                Iterator<Tuple2<Integer, Row>> it = iterable.iterator();
                while (it.hasNext()) {
                    int i2 = i;
                    i++;
                    rowArr[i2] = (Row) it.next().f1;
                }
                // Long.compare instead of a narrowed subtraction: a count difference outside the
                // int range would otherwise be truncated and could flip the ordering.
                Arrays.sort(rowArr, (row, row2) -> {
                    return Long.compare(((Long) row.getField(1)).longValue(), ((Long) row2.getField(1)).longValue());
                });
                int position = this.start;
                for (Row row3 : rowArr) {
                    // Ascending position -> descending index.
                    collector.collect(RowUtil.merge(row3, Integer.valueOf(this.total - position - 1)));
                    position++;
                }
            }
        }).withBroadcastSet(DataSetUtils.countElementsPerPartition(partitionCustom), WordCountUtil.COUNT_COL_NAME);
    }

    /**
     * Replaces the words of every document by their vocabulary indices.
     *
     * <p>Each token is tagged with (subtask, document number within the subtask, position), matched
     * against the vocabulary by word text, and the matches are regrouped per document and ordered
     * by position. Tokens without a vocabulary entry produce no output; null or empty documents are
     * skipped and do not consume a document number.
     *
     * @param dataSet  documents as arrays of tokens
     * @param dataSet2 vocabulary tuples (index, word text, Word)
     * @return one array of vocabulary indices per document that has at least one matched token
     */
    private static DataSet<int[]> encodeContent(DataSet<String[]> dataSet, DataSet<Tuple3<Integer, String, Word>> dataSet2) {
        return dataSet
            .mapPartition(new RichMapPartitionFunction<String[], Tuple4<Integer, Long, Integer, String>>() {
                private static final long serialVersionUID = 2985519984072344725L;

                public void mapPartition(Iterable<String[]> iterable, Collector<Tuple4<Integer, Long, Integer, String>> collector) throws Exception {
                    int subtask = getRuntimeContext().getIndexOfThisSubtask();
                    long docNumber = 0;
                    for (String[] tokens : iterable) {
                        if (tokens == null || tokens.length == 0) {
                            continue;
                        }
                        int position = 0;
                        for (String token : tokens) {
                            collector.collect(Tuple4.of(subtask, docNumber, position, token));
                            position++;
                        }
                        docNumber++;
                    }
                }
            })
            .coGroup(dataSet2)
            .where(3)
            .equalTo(1)
            .with(new CoGroupFunction<Tuple4<Integer, Long, Integer, String>, Tuple3<Integer, String, Word>, Tuple4<Integer, Long, Integer, Integer>>() {
                private static final long serialVersionUID = -4187624436127997613L;

                public void coGroup(Iterable<Tuple4<Integer, Long, Integer, String>> iterable, Iterable<Tuple3<Integer, String, Word>> iterable2, Collector<Tuple4<Integer, Long, Integer, Integer>> collector) {
                    for (Tuple3<Integer, String, Word> vocabEntry : iterable2) {
                        Integer wordIndex = vocabEntry.f0;
                        for (Tuple4<Integer, Long, Integer, String> token : iterable) {
                            collector.collect(Tuple4.of(token.f0, token.f1, token.f2, wordIndex));
                        }
                    }
                }
            })
            .groupBy(0, 1)
            .reduceGroup(new GroupReduceFunction<Tuple4<Integer, Long, Integer, Integer>, int[]>() {
                private static final long serialVersionUID = 8323725437283683721L;

                public void reduce(Iterable<Tuple4<Integer, Long, Integer, Integer>> iterable, Collector<int[]> collector) {
                    List<Tuple2<Integer, Integer>> positioned = new ArrayList<>();
                    for (Tuple4<Integer, Long, Integer, Integer> token : iterable) {
                        positioned.add(Tuple2.of(token.f2, token.f3));
                    }
                    // Restore the original token order within the document.
                    positioned.sort((left, right) -> left.f0.compareTo(right.f0));
                    int[] encoded = new int[positioned.size()];
                    int slot = 0;
                    for (Tuple2<Integer, Integer> entry : positioned) {
                        encoded[slot++] = entry.f1.intValue();
                    }
                    collector.collect(encoded);
                }
            });
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Builds the Word2Vec training dataflow on the first input operator.
     *
     * <p>Stages, in order: build the count-filtered, indexed vocabulary; attach tree codes to it
     * (CreateVocab); encode the documents of the selected column as arrays of vocabulary indices;
     * run the IterativeComQueue training loop; split its result rows into the train-info side
     * output (rows whose field 0 is 0) and the model output (rows whose field 0 is 1, joined back
     * to the word text by index).
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator, with its output and side output set
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public Word2VecTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> checkAndGetFirst = checkAndGetFirst(batchOperatorArr);
        int intValue = getVectorSize().intValue();
        // Word counts of the selected column, restricted to cnt >= minCount, with an index appended.
        DataSet<Row> sortedIndexVocab = sortedIndexVocab(WordCountUtil.splitDocAndCount(checkAndGetFirst, getSelectedCol(), getWordDelimiter()).filter("cnt >= " + String.valueOf(getMinCount())).getDataSet());
        // Vocabulary size: sum of the per-partition element counts.
        MapOperator map = DataSetUtils.countElementsPerPartition(sortedIndexVocab).sum(1).map(new MapFunction<Tuple2<Integer, Long>, Long>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.6
            private static final long serialVersionUID = -4608562797891228404L;

            public Long map(Tuple2<Integer, Long> tuple2) throws Exception {
                return (Long) tuple2.f1;
            }
        });
        // (index, word, Word) tuples; CreateVocab receives the vocabulary size as broadcast "vocSize".
        PartitionOperator rebalance = sortedIndexVocab.reduceGroup(new CreateVocab()).withBroadcastSet(map, "vocSize").rebalance();
        // Documents of the selected column, split into tokens and encoded as vocabulary indices.
        PartitionOperator rebalance2 = encodeContent(checkAndGetFirst.select("`" + getSelectedCol() + "`").getDataSet().flatMap(new WordCountUtil.WordSpliter(getWordDelimiter())).rebalance(), rebalance).rebalance();
        // Vocabulary without the word text, used inside the training loop.
        MapOperator map2 = rebalance.map(new UseVocabWithoutWordString());
        // Training loop. The broadcast "syncNum" is max(n / 100000, min(max(1, n / parallelism), 5)),
        // where n is the total number of encoded documents; UpdateModel splits the local data into
        // that many slices and Criterion stops after syncNum * NUM_ITER steps.
        DataSet<Row> exec = new IterativeComQueue().initWithPartitionedData("trainData", rebalance2).initWithBroadcastData("vocSize", map).initWithBroadcastData("initialModel", map2.mapPartition(new initialModel(((Long) getParams().get(HasSeed.SEED)).longValue(), intValue)).rebalance()).initWithBroadcastData("vocabWithoutWordStr", map2).initWithBroadcastData("syncNum", DataSetUtils.countElementsPerPartition(rebalance2).sum(1).map(new RichMapFunction<Tuple2<Integer, Long>, Integer>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.7
            private static final long serialVersionUID = 2989627778876891178L;

            public Integer map(Tuple2<Integer, Long> tuple2) throws Exception {
                return Integer.valueOf(Math.max((int) (((Long) tuple2.f1).longValue() / 100000), Math.min(Math.max(1, (int) (((Long) tuple2.f1).longValue() / getRuntimeContext().getNumberOfParallelSubtasks())), 5)));
            }
        })).add(new InitialVocabAndBuffer(getParams())).add(new UpdateModel(getParams())).add(new AllReduce("input")).add(new AllReduce("output")).add(new AllReduce("lossInfo")).add(new AvgInputOutput()).setCompareCriterionOfNode0((CompareCriterionFunction) new Criterion(getParams())).closeWith(new SerializeModel(getParams())).exec();
        // Side output: the JSON in field 3 of the rows whose field 0 is 0, as a single "info" column.
        setSideOutputTables(new Table[]{DataSetConversionUtil.toTable(getMLEnvironmentId(), (DataSet<Row>) exec.filter(new FilterFunction<Row>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.9
            private static final long serialVersionUID = -4087427554304465241L;

            public boolean filter(Row row) throws Exception {
                return ((Integer) row.getField(0)).intValue() == 0;
            }
        }).map(new MapFunction<Row, Row>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.8
            private static final long serialVersionUID = 2279565281381999504L;

            public Row map(Row row) throws Exception {
                return Row.of(new Object[]{row.getField(3)});
            }
        }), new String[]{"info"}, (TypeInformation<?>[]) new TypeInformation[]{Types.STRING})});
        // Model output: rows whose field 0 is 1 are reduced to (index, vector), joined with the
        // vocabulary on the index to obtain (word, vector), and written through the model converter.
        setOutput((DataSet<Row>) exec.filter(new FilterFunction<Row>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.14
            private static final long serialVersionUID = -4087427554304465241L;

            public boolean filter(Row row) throws Exception {
                return ((Integer) row.getField(0)).intValue() == 1;
            }
        }).map(new MapFunction<Row, Row>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.13
            private static final long serialVersionUID = 2279565281381999504L;

            public Row map(Row row) throws Exception {
                return Row.of(new Object[]{row.getField(1), row.getField(2)});
            }
        }).map(new MapFunction<Row, Tuple2<Integer, DenseVector>>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.12
            private static final long serialVersionUID = 10165543447930471L;

            public Tuple2<Integer, DenseVector> map(Row row) throws Exception {
                return Tuple2.of((Integer) row.getField(0), (DenseVector) row.getField(1));
            }
        }).join(rebalance).where(new int[]{0}).equalTo(new int[]{0}).with(new JoinFunction<Tuple2<Integer, DenseVector>, Tuple3<Integer, String, Word>, Row>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.11
            private static final long serialVersionUID = 5611294863047638770L;

            public Row join(Tuple2<Integer, DenseVector> tuple2, Tuple3<Integer, String, Word> tuple3) throws Exception {
                return Row.of(new Object[]{tuple3.f1, tuple2.f1});
            }
        }).mapPartition(new MapPartitionFunction<Row, Row>() { // from class: com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp.10
            private static final long serialVersionUID = -3274399290123772498L;

            public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
                // Collects all (word, vector) rows of the partition and hands them to the converter.
                Word2VecModelDataConverter word2VecModelDataConverter = new Word2VecModelDataConverter();
                word2VecModelDataConverter.modelRows = (List) StreamSupport.stream(iterable.spliterator(), false).collect(Collectors.toList());
                word2VecModelDataConverter.save2(word2VecModelDataConverter, collector);
            }
        }), new Word2VecModelDataConverter().getModelSchema());
        return this;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Wraps the given rows in a Word2VecTrainInfo.
     *
     * @param list rows passed to the Word2VecTrainInfo constructor
     * @return the new train-info object
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithTrainInfo
    public Word2VecTrainInfo createTrainInfo(List<Row> list) {
        return new Word2VecTrainInfo(list);
    }

    /**
     * Returns side output 0, which linkFrom fills with the single-column "info" table.
     *
     * @return the operator over side output 0
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithTrainInfo
    public BatchOperator<?> getSideOutputTrainInfo() {
        return getSideOutput(0);
    }

    // Bridge method marked synthetic by the decompiler; it only forwards to the typed linkFrom overload.
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ Word2VecTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }

    // Bridge method marked synthetic by the decompiler; it only forwards to the typed createTrainInfo overload.
    @Override // com.alibaba.alink.operator.batch.utils.WithTrainInfo
    public /* bridge */ /* synthetic */ Word2VecTrainInfo createTrainInfo(List list) {
        return createTrainInfo((List<Row>) list);
    }
}
