package com.alibaba.alink.operator.common.clustering.lda;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.List;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/lda/OnlineCorpusStep.class */
/**
 * Compute step that runs one online corpus update on the documents held by the current worker and
 * publishes the partial statistics into the {@link ComContext}.
 *
 * <p>On every call it writes four {@code double[]} objects under {@code LdaVariable.wordTopicStat},
 * {@code LdaVariable.logPhatPart}, {@code LdaVariable.nonEmptyWordCount} and
 * {@code LdaVariable.nonEmptyDocCount}.
 */
public class OnlineCorpusStep extends ComputeFunction {
    private static final long serialVersionUID = 8719846772171941445L;
    private final int numTopic;
    private final double subSamplingRate;
    private final int gammaShape;
    /** Optional seed; when {@code null} the generator keeps its default seeding. */
    private final Integer seed;
    private final RandomDataGenerator random = new RandomDataGenerator();
    /** Set once the generator has been re-seeded, so that re-seeding happens at most once. */
    private boolean addedIndex = false;

    /**
     * @param numTopic        number of topics (row count of the produced statistics)
     * @param subSamplingRate fraction of the local documents drawn per call
     * @param gammaShape      shape value forwarded to {@code LdaUtil.geneGamma}
     * @param seed            seed for the random generator, or {@code null} to leave it unseeded
     */
    public OnlineCorpusStep(int numTopic, double subSamplingRate, int gammaShape, Integer seed) {
        this.numTopic = numTopic;
        this.subSamplingRate = subSamplingRate;
        this.gammaShape = gammaShape;
        this.seed = seed;
    }

    /**
     * Reads the vector size, the local documents and the current model from the context, runs
     * {@link #onlineCorpusUpdate} and stores the resulting arrays back into the context.
     *
     * <p>The model comes from {@code LdaVariable.initModel} on step 1 and from
     * {@code LdaVariable.lambda} / {@code LdaVariable.alpha} on later steps. When the worker has
     * no documents, zero-filled statistics and placeholder counts are stored instead.
     *
     * @param context the communication context of the current iteration
     */
    @Override
    public void calc(ComContext context) {
        // Re-seed lazily and only once; later steps continue the same random stream.
        if (!this.addedIndex && this.seed != null) {
            this.random.reSeed(this.seed.intValue());
            this.addedIndex = true;
        }

        Tuple2<?, ?> shape = (Tuple2<?, ?>) ((List<?>) context.getObj(LdaVariable.shape)).get(0);
        int vocabSize = ((Integer) shape.f1).intValue();

        @SuppressWarnings("unchecked")
        List<Vector> docs = (List<Vector>) context.getObj("data");

        DenseMatrix lambda;
        DenseMatrix alpha;
        if (context.getStepNo() == 1) {
            Tuple2<?, ?> initModel = (Tuple2<?, ?>) ((List<?>) context.getObj(LdaVariable.initModel)).get(0);
            lambda = (DenseMatrix) initModel.f0;
            alpha = (DenseMatrix) initModel.f1;
        } else {
            lambda = (DenseMatrix) context.getObj(LdaVariable.lambda);
            alpha = (DenseMatrix) context.getObj(LdaVariable.alpha);
        }

        if (docs == null || docs.isEmpty()) {
            // Nothing to sample on this worker: publish zero statistics of the expected lengths.
            context.putObj(LdaVariable.wordTopicStat, new double[this.numTopic * vocabSize]);
            context.putObj(LdaVariable.logPhatPart, new double[this.numTopic]);
            context.putObj(LdaVariable.nonEmptyWordCount, new double[]{Criteria.INVALID_GAIN});
            context.putObj(LdaVariable.nonEmptyDocCount, new double[]{Criteria.INVALID_GAIN});
            return;
        }

        Tuple4<DenseMatrix, DenseMatrix, Long, Long> update = onlineCorpusUpdate(
            docs, lambda, alpha, null, vocabSize, this.numTopic, this.subSamplingRate, this.random, this.gammaShape);

        // Store copies so the context does not share arrays with the matrices built above.
        context.putObj(LdaVariable.wordTopicStat, update.f0.getData().clone());
        context.putObj(LdaVariable.logPhatPart, update.f1.getData().clone());
        context.putObj(LdaVariable.nonEmptyWordCount, new double[]{update.f2.doubleValue()});
        context.putObj(LdaVariable.nonEmptyDocCount, new double[]{update.f3.doubleValue()});
    }

    /**
     * Draws a random sample of documents and accumulates their statistics.
     *
     * <p>Each sampled document is cast to {@link SparseVector}, resized to {@code vocabSize} and
     * has {@code removeZeroValues()} called on it, i.e. the vectors held in {@code docs} are
     * modified by this call. Indices are drawn independently, so a document may be picked more
     * than once.
     *
     * @param docs            the local documents; every sampled element must be a {@link SparseVector}
     * @param lambda          matrix passed to {@code LdaUtil.expDirichletExpectation}
     * @param alpha           matrix forwarded to {@code LdaUtil.getTopicDistributionMethod}
     * @param unusedMatrix    not read by this method; callers in this class pass {@code null}
     * @param vocabSize       size applied to every sampled vector and column count of the statistics
     * @param numTopic        number of topics (row count of the statistics)
     * @param subSamplingRate fraction of {@code docs} to draw
     * @param random          random source used for sampling and for {@code LdaUtil.geneGamma}
     * @param gammaShape      shape value forwarded to {@code LdaUtil.geneGamma}
     * @return a tuple of (the {@code numTopic x vocabSize} accumulated statistics, the
     *         {@code numTopic x 1} accumulated {@code dirichletExpectationVec} values, the summed
     *         vector values of the sampled documents, the number of sampled documents)
     */
    public static Tuple4<DenseMatrix, DenseMatrix, Long, Long> onlineCorpusUpdate(
        List<Vector> docs, DenseMatrix lambda, DenseMatrix alpha, DenseMatrix unusedMatrix, int vocabSize,
        int numTopic, double subSamplingRate, RandomDataGenerator random, int gammaShape) {

        DenseMatrix wordTopicStat = DenseMatrix.zeros(numTopic, vocabSize);
        DenseMatrix logPhatPart = new DenseMatrix(numTopic, 1);
        DenseMatrix expLambdaTransposed = LdaUtil.expDirichletExpectation(lambda).transpose();
        long wordCount = 0;
        long docCount = 0;

        for (int docIndex : generateOnlineDocs(docs.size(), subSamplingRate, random)) {
            SparseVector doc = (SparseVector) docs.get(docIndex);
            doc.setSize(vocabSize);
            doc.removeZeroValues();

            double[] values = doc.getValues();
            int valueCount = doc.numberOfValues();
            for (int k = 0; k < valueCount; k++) {
                // Compound assignment narrows the double sum back to long on every step,
                // so any fractional part is truncated per addition.
                wordCount += values[k];
            }

            Tuple2<DenseMatrix, DenseMatrix> topicResult = LdaUtil.getTopicDistributionMethod(
                doc, expLambdaTransposed, alpha, LdaUtil.geneGamma(numTopic, gammaShape, random), numTopic);

            // Column k of topicResult.f1 belongs to the k-th stored entry of the document;
            // scatter it into the column given by that entry's index.
            DenseMatrix perEntryStat = topicResult.f1;
            int[] indices = doc.getIndices();
            for (int k = 0; k < indices.length; k++) {
                for (int topic = 0; topic < numTopic; topic++) {
                    wordTopicStat.add(topic, indices[k], perEntryStat.get(topic, k));
                }
            }

            DenseMatrix expectation = LdaUtil.dirichletExpectationVec(topicResult.f0);
            for (int topic = 0; topic < numTopic; topic++) {
                logPhatPart.add(topic, 0, expectation.get(topic, 0));
            }
            docCount++;
        }
        return new Tuple4<>(wordTopicStat, logPhatPart, Long.valueOf(wordCount), Long.valueOf(docCount));
    }

    /**
     * Draws {@code ceil(docTotal * subSamplingRate)} document indices uniformly from
     * {@code [0, docTotal - 1]}, each draw independent of the others.
     *
     * @param docTotal        number of available documents
     * @param subSamplingRate fraction of documents to draw
     * @param random          random source
     * @return the drawn indices; empty when the computed sample size is zero
     */
    private static int[] generateOnlineDocs(int docTotal, double subSamplingRate, RandomDataGenerator random) {
        int sampleSize = (int) Math.ceil(docTotal * subSamplingRate);
        int[] sampled = new int[sampleSize];
        if (docTotal <= 1) {
            // With a single document the only valid index is 0, which the array already holds.
            // nextInt(0, 0) is outside the documented contract (lower < upper) and is declared
            // to throw NumberIsTooLargeException, so it must not be called here.
            return sampled;
        }
        int maxIndex = docTotal - 1;
        for (int k = 0; k < sampleSize; k++) {
            sampled[k] = random.nextInt(0, maxIndex);
        }
        return sampled;
    }
}
