package com.alibaba.alink.operator.common.clustering.lda;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.flink.api.java.tuple.Tuple2;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/lda/OnlineLogLikelihood.class */
/**
 * Comqueue compute step that evaluates the LDA log-likelihood contributed by the data held by the
 * current task.
 *
 * <p>At step 1 it registers a one-element {@code double[]} under {@code LdaVariable.logLikelihood}.
 * At step {@code numIter} it reads the local documents, {@code lambda}, {@code alpha} and the shape
 * information from the context, evaluates {@link #logLikelihood} and writes the result into slot 0
 * of that array.
 */
public class OnlineLogLikelihood extends ComputeFunction {
    private static final long serialVersionUID = -675551201422550951L;
    /** Scalar prior used in the topic-term part of {@link #logLikelihood}. */
    private double beta;
    /** Number of topics. */
    private int numTopic;
    /** Step number at which {@link #calc} computes the log-likelihood. */
    private int numIter;
    /** Forwarded to {@code LdaUtil.geneGamma} when drawing initial per-document gammas. */
    private int gammaShape;
    /** Optional seed for {@link #random}; {@code null} leaves the generator unseeded. */
    private Integer seed;
    private RandomDataGenerator random = new RandomDataGenerator();
    /** True once {@link #random} has been re-seeded with {@link #seed}; seeding happens at most once. */
    private boolean addedIndex = false;

    /**
     * @param beta       scalar prior used in the topic-term part of the log-likelihood
     * @param numTopic   number of topics
     * @param numIter    step number at which {@link #calc} computes the log-likelihood
     * @param gammaShape shape argument forwarded to {@code LdaUtil.geneGamma}
     * @param seed       seed for the internal random generator, or {@code null} to leave it unseeded
     */
    public OnlineLogLikelihood(double beta, int numTopic, int numIter, int gammaShape, Integer seed) {
        this.beta = beta;
        this.numTopic = numTopic;
        this.numIter = numIter;
        this.gammaShape = gammaShape;
        this.seed = seed;
    }

    /**
     * Computes the (online-LDA style) log-likelihood contributed by one task's documents.
     *
     * <p>The result is the sum over {@code data} of a per-document term (word part plus
     * document-topic prior terms) plus a topic-term prior part divided by {@code taskNum}.
     *
     * <p>Side effect: {@code removeZeroValues()} is called on every document vector, so the
     * caller's vectors may be modified.
     *
     * @param data           documents; every element must be a {@link SparseVector} whose indices
     *                       address columns of {@code lambda}. {@code null} is treated as no documents
     * @param lambda         topic-word parameters with {@code numTopic} rows (read from
     *                       {@code LdaVariable.lambda} in {@link #calc})
     * @param alpha          document-topic prior (read from {@code LdaVariable.alpha} in {@link #calc})
     * @param gammad         starting topic distribution for the first document, or {@code null} to draw
     *                       a fresh one per document with {@code LdaUtil.geneGamma}. When non-null,
     *                       each later document starts from the previous document's inferred gamma
     * @param numTopic       number of topics
     * @param vocabularySize multiplied with {@code beta} in the topic-term part; {@link #calc} passes
     *                       {@code f1} of the first shape entry (presumably the vocabulary size)
     * @param beta           scalar topic-term prior
     * @param taskNum        divisor applied to the topic-term part
     * @param gammaShape     shape argument forwarded to {@code LdaUtil.geneGamma}
     * @param random         generator used when gammas are drawn
     * @return the corpus part plus the topic-term part divided by {@code taskNum}
     * @throws ClassCastException if an element of {@code data} is not a {@link SparseVector}
     */
    public static double logLikelihood(List<Vector> data, DenseMatrix lambda, DenseMatrix alpha,
                                       DenseMatrix gammad, int numTopic, int vocabularySize,
                                       double beta, int taskNum, int gammaShape,
                                       RandomDataGenerator random) {
        boolean drawGammaPerDoc = gammad == null;
        // Transposed so that row w holds the numTopic per-topic values of word index w.
        DenseMatrix eLogBeta = LdaUtil.dirichletExpectation(lambda).transpose();
        DenseMatrix expELogBeta = LdaUtil.expDirichletExpectation(lambda).transpose();

        double corpusPart = 0.0d;
        if (data != null) {
            DenseMatrix gamma = gammad;
            for (Vector vector : data) {
                SparseVector doc = (SparseVector) vector;
                doc.removeZeroValues();
                if (drawGammaPerDoc) {
                    gamma = LdaUtil.geneGamma(numTopic, gammaShape, random);
                }
                // Infer this document's topic distribution, starting from the current gamma.
                gamma = (DenseMatrix) LdaUtil.getTopicDistributionMethod(
                    doc, expELogBeta, alpha, gamma, numTopic).f0;
                DenseMatrix eLogTheta = LdaUtil.dirichletExpectationVec(gamma);

                // Word part: sum over non-zeros of value * logSumExp(eLogTheta + eLogBeta[word]).
                double wordPart = 0.0d;
                for (int k = 0; k < doc.numberOfValues(); k++) {
                    DenseMatrix eLogBetaWord =
                        new DenseMatrix(numTopic, 1, eLogBeta.getRow(doc.getIndices()[k]));
                    wordPart += doc.getValues()[k] * LdaUtil.logSumExp(eLogTheta.plus(eLogBetaWord));
                }

                // Document-topic prior terms; summation order matches the original expression.
                corpusPart += wordPart
                    + LdaUtil.elementWiseProduct(alpha.minus(gamma), eLogTheta).sum()
                    + LdaUtil.logGamma(gamma).minus(LdaUtil.logGamma(alpha)).sum()
                    + (LdaUtil.logGamma(alpha.sum()) - LdaUtil.logGamma(gamma.sum()));
            }
        }

        // Topic-term prior part. lambda.transpose() is deliberately recomputed for each term so
        // that no intermediate matrix is shared between helper calls.
        double crossTerm = LdaUtil.elementWiseProduct(
            lambda.transpose().plus(-beta).scale(-1.0d), eLogBeta).sum();
        double logGammaLambda =
            LdaUtil.logGamma(lambda.transpose()).plus(-LdaUtil.logGamma(beta)).sum();
        double logGammaSums = LdaUtil.logGamma(LdaUtil.sumByRow(lambda.transpose()))
            .plus(-LdaUtil.logGamma(beta * vocabularySize)).sum();
        double topicsPart = (crossTerm + logGammaLambda) - logGammaSums;

        // NOTE(review): presumably divided because every task adds this same model-level term and
        // the per-task results are summed elsewhere -- confirm against the caller.
        return corpusPart + (topicsPart / taskNum);
    }

    /**
     * Seeds the generator once (if a seed was given), allocates the result slot at step 1 and
     * computes this task's log-likelihood at step {@code numIter}.
     */
    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        int stepNo = comContext.getStepNo();
        if (!this.addedIndex && this.seed != null) {
            this.random.reSeed(this.seed.intValue());
            this.addedIndex = true;
        }
        if (stepNo == 1) {
            comContext.putObj(LdaVariable.logLikelihood, new double[1]);
        }
        if (stepNo == this.numIter) {
            double[] result = (double[]) comContext.getObj(LdaVariable.logLikelihood);
            @SuppressWarnings("unchecked") // context objects are stored untyped
            List<Tuple2<?, Integer>> shape =
                (List<Tuple2<?, Integer>>) comContext.getObj(LdaVariable.shape);
            int vocabularySize = shape.get(0).f1;
            @SuppressWarnings("unchecked") // context objects are stored untyped
            List<Vector> data = (List<Vector>) comContext.getObj("data");
            DenseMatrix lambda = (DenseMatrix) comContext.getObj(LdaVariable.lambda);
            DenseMatrix alpha = (DenseMatrix) comContext.getObj(LdaVariable.alpha);
            result[0] = logLikelihood(data, lambda, alpha, null, this.numTopic, vocabularySize,
                this.beta, comContext.getNumTask(), this.gammaShape, this.random);
            comContext.putObj(LdaVariable.logLikelihood, result);
        }
    }
}
