package com.alibaba.alink.operator.common.clustering.lda;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.VectorIterator;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.List;
import org.apache.commons.math3.random.RandomDataGenerator;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/lda/EmCorpusStep.class */
/**
 * Per-worker compute step of the EM (sampling based) LDA trainer.
 *
 * <p>On step 1 the step expands every input {@link SparseVector} into a {@code Document} of
 * individual tokens, assigns each token a random topic and stores the resulting corpus and count
 * matrices in the {@link ComContext}. On every later step it re-samples the topic of every token
 * from the current counts and then publishes the word-topic counts rebuilt from the local corpus.
 *
 * <p>The word-topic matrix has {@code vocabularySize + 1} rows: row {@code w} holds the per-topic
 * counts of word {@code w}, and the extra last row (index {@code vocabularySize}) holds the
 * per-topic totals over all words.
 */
public class EmCorpusStep extends ComputeFunction {
    private static final long serialVersionUID = 2345303930235043884L;
    private int numTopic;
    private double alpha;
    private double beta;
    private Integer seed;
    /** Set once the random generator has been re-seeded, so seeding happens at most once. */
    private boolean addedIndex = false;
    private RandomDataGenerator rand = new RandomDataGenerator();

    /**
     * @param i   number of topics
     * @param d   alpha, the smoothing term added to document-topic counts
     * @param d2  beta, the smoothing term added to word-topic counts
     * @param num seed for the random generator, or {@code null} to leave it unseeded
     */
    public EmCorpusStep(int i, double d, double d2, Integer num) {
        this.numTopic = i;
        this.alpha = d;
        this.beta = d2;
        this.seed = num;
    }

    /**
     * Runs one iteration: corpus initialisation on step 1, topic re-sampling afterwards.
     *
     * @param comContext context holding the vocabulary size, the input data and the shared state
     */
    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        if (!this.addedIndex && this.seed != null) {
            this.rand.reSeed(this.seed.intValue());
            this.addedIndex = true;
        }
        List<?> vocabularySizeHolder = (List<?>) comContext.getObj(LdaVariable.vocabularySize);
        int vocabularySize = ((Integer) vocabularySizeHolder.get(0)).intValue();

        if (comContext.getStepNo() == 1) {
            initializeCorpus(comContext, vocabularySize);
            return;
        }

        Document[] corpus = (Document[]) comContext.getObj(LdaVariable.corpus);
        if (corpus == null) {
            // This worker received no data on step 1; nothing to sample.
            return;
        }
        DenseMatrix docTopicCounts = (DenseMatrix) comContext.getObj(LdaVariable.nDocTopics);
        DenseMatrix wordTopicCounts = new DenseMatrix(vocabularySize + 1, this.numTopic,
            (double[]) comContext.getObj(LdaVariable.nWordTopics), false);

        resampleTopics(corpus, docTopicCounts, wordTopicCounts, vocabularySize);

        // Publish the counts implied by the local corpus only, replacing the array read above.
        // NOTE(review): presumably the stored array is combined across workers by another step
        // before the next iteration - confirm against the surrounding queue definition.
        DenseMatrix localCounts = countLocalWordTopics(corpus, wordTopicCounts.numRows(),
            wordTopicCounts.numCols(), vocabularySize);
        comContext.putObj(LdaVariable.nWordTopics, localCounts.getData());
    }

    /** Builds the token-level corpus with random topic assignments and stores the initial state. */
    private void initializeCorpus(ComContext comContext, int vocabularySize) {
        DenseMatrix wordTopicCounts = new DenseMatrix(vocabularySize + 1, this.numTopic);
        // Stored before the null check so that workers without data still expose a zero matrix.
        comContext.putObj(LdaVariable.nWordTopics, wordTopicCounts.getData());

        @SuppressWarnings("unchecked")
        List<SparseVector> data = (List<SparseVector>) comContext.getObj("data");
        if (data == null) {
            return;
        }

        Document[] corpus = new Document[data.size()];
        DenseMatrix docTopicCounts = new DenseMatrix(data.size(), this.numTopic);
        int docId = 0;
        for (SparseVector vector : data) {
            // The length uses the same per-entry truncation as the expansion loop below, so the
            // document always has exactly one slot per generated token.
            Document doc = new Document(countTokens(vector));
            int position = 0;
            VectorIterator it = vector.iterator();
            while (it.hasNext()) {
                int wordId = it.getIndex();
                int occurrences = (int) it.getValue();
                for (int n = 0; n < occurrences; n++) {
                    int topic = nextRandomTopic();
                    doc.setWordIdxs(position, wordId);
                    doc.setTopicIdxs(position, topic);
                    updateDocWordTopics(docTopicCounts, wordTopicCounts, docId, wordId, vocabularySize, topic, 1);
                    position++;
                }
                it.next();
            }
            corpus[docId] = doc;
            docId++;
        }
        comContext.putObj(LdaVariable.corpus, corpus);
        comContext.putObj(LdaVariable.nDocTopics, docTopicCounts);
        // The raw vectors are no longer needed once the token-level corpus exists.
        comContext.removeObj("data");
    }

    /** Returns the number of tokens a vector expands to: the sum of its positive truncated values. */
    private int countTokens(SparseVector vector) {
        int total = 0;
        for (double value : vector.getValues()) {
            int occurrences = (int) value;
            if (occurrences > 0) {
                total += occurrences;
            }
        }
        return total;
    }

    /** Draws a topic uniformly from {@code [0, numTopic - 1]}. */
    private int nextRandomTopic() {
        // RandomDataGenerator.nextInt is documented to reject lower >= upper, so the
        // single-topic case must not reach it; the only valid topic is 0 there.
        if (this.numTopic <= 1) {
            return 0;
        }
        return this.rand.nextInt(0, this.numTopic - 1);
    }

    /**
     * Re-samples the topic of every token. For each token its current assignment is removed from
     * the counts, a new topic is drawn with weight
     * {@code (wordTopic + beta) * (docTopic + alpha) / (topicTotal + vocabularySize * beta)},
     * and the counts are updated with the new assignment.
     */
    private void resampleTopics(Document[] corpus, DenseMatrix docTopicCounts, DenseMatrix wordTopicCounts,
                                int vocabularySize) {
        double[] cumulative = new double[this.numTopic];
        double betaMass = vocabularySize * this.beta;
        for (int docId = 0; docId < corpus.length; docId++) {
            Document doc = corpus[docId];
            int length = doc.getLength();
            for (int position = 0; position < length; position++) {
                int wordId = doc.getWordIdxs(position);
                updateDocWordTopics(docTopicCounts, wordTopicCounts, docId, wordId, vocabularySize,
                    doc.getTopicIdxs(position), -1);

                // Running sum of the unnormalised topic weights.
                double total = 0.0d;
                for (int topic = 0; topic < this.numTopic; topic++) {
                    total += ((wordTopicCounts.get(wordId, topic) + this.beta)
                        * (docTopicCounts.get(docId, topic) + this.alpha))
                        / (wordTopicCounts.get(vocabularySize, topic) + betaMass);
                    cumulative[topic] = total;
                }

                // Scale a uniform draw to the total mass and locate it in the running sum.
                int newTopic = findProbIdx(cumulative, this.rand.nextUniform(0.0d, 1.0d) * total);
                doc.setTopicIdxs(position, newTopic);
                updateDocWordTopics(docTopicCounts, wordTopicCounts, docId, wordId, vocabularySize, newTopic, 1);
            }
        }
    }

    /** Rebuilds a word-topic count matrix (including the totals row) from the corpus assignments. */
    private DenseMatrix countLocalWordTopics(Document[] corpus, int rows, int cols, int vocabularySize) {
        DenseMatrix counts = new DenseMatrix(rows, cols);
        for (Document doc : corpus) {
            int length = doc.getLength();
            for (int position = 0; position < length; position++) {
                int topic = doc.getTopicIdxs(position);
                counts.add(doc.getWordIdxs(position), topic, 1.0d);
                counts.add(vocabularySize, topic, 1.0d);
            }
        }
        return counts;
    }

    /**
     * Returns the first index whose cumulative value is at least {@code d}, or the last index when
     * no entry qualifies.
     */
    private int findProbIdx(double[] dArr, double d) {
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] >= d) {
                return i;
            }
        }
        return dArr.length - 1;
    }

    /**
     * Adds {@code delta} to the document-topic count, the word-topic count and the per-topic total
     * (stored in row {@code totalsRow} of the word-topic matrix).
     */
    private void updateDocWordTopics(DenseMatrix docTopicCounts, DenseMatrix wordTopicCounts, int docId, int wordId,
                                     int totalsRow, int topic, int delta) {
        docTopicCounts.add(docId, topic, delta);
        wordTopicCounts.add(wordId, topic, delta);
        wordTopicCounts.add(totalsRow, topic, delta);
    }
}
