package com.alibaba.alink.operator.common.clustering.lda;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.operator.common.clustering.LdaModelData;
import com.google.common.collect.Lists;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.types.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/lda/BuildEmLdaModel.class */
/**
 * Completion step of the EM LDA iteration: on task 0 it assembles an {@link LdaModelData} from
 * the values stored in the {@link ComContext} and emits it as a single {@link Row}.
 */
public class BuildEmLdaModel extends CompleteResultFunction {
    private static final Logger LOG = LoggerFactory.getLogger(BuildEmLdaModel.class);
    private static final long serialVersionUID = 8585414464252665425L;

    /** Number of topics; used as the column count of the word-topic matrix. */
    private int topicNum;
    /** Value replicated into every entry of the model's per-topic alpha array. */
    private double alpha;
    /** Value replicated into every entry of the model's per-topic beta array. */
    private double beta;

    /**
     * @param topicNum number of topics
     * @param alpha    value copied into each of the {@code topicNum} alpha entries
     * @param beta     value copied into each of the {@code topicNum} beta entries
     */
    public BuildEmLdaModel(int topicNum, double alpha, double beta) {
        this.topicNum = topicNum;
        this.alpha = alpha;
        this.beta = beta;
    }

    /**
     * Builds the model on task 0.
     *
     * <p>Reads, in order, {@code LdaVariable.logLikelihood} (first element used),
     * {@code LdaVariable.vocabularySize} (first list element used) and
     * {@code LdaVariable.nWordTopics} from the context.
     *
     * @param comContext communication context holding the accumulated training state
     * @return a one-element list containing the model row on task 0; {@code null} on every other task
     */
    @Override
    public List<Row> calc(ComContext comContext) {
        // Only task 0 produces the model; the other tasks contribute nothing.
        if (comContext.getTaskId() != 0) {
            return null;
        }

        final double logLikelihood = ((double[]) comContext.getObj(LdaVariable.logLikelihood))[0];
        LOG.info("em logLikelihood: {}", logLikelihood);

        final List<?> vocabularySizes = (List<?>) comContext.getObj(LdaVariable.vocabularySize);
        final int vocabularySize = (Integer) vocabularySizes.get(0);

        // NOTE(review): the matrix gets vocabularySize + 1 rows; what the extra row holds is not
        // visible here. The trailing `false` is forwarded to DenseMatrix unchanged (presumably a
        // row-major flag) -- confirm against DenseMatrix.
        final double[] wordTopicCounts = (double[]) comContext.getObj(LdaVariable.nWordTopics);
        final DenseMatrix wordTopics =
            new DenseMatrix(vocabularySize + 1, this.topicNum, wordTopicCounts, false);

        final LdaModelData modelData = new LdaModelData(
            this.topicNum,
            vocabularySize,
            wordTopics,
            filled(this.topicNum, this.alpha),
            filled(this.topicNum, this.beta));
        modelData.logLikelihood = logLikelihood;
        // NOTE(review): normalizes by vocabulary size rather than by token count -- confirm intended.
        modelData.logPerplexity = -modelData.logLikelihood / vocabularySize;

        return Lists.newArrayList(Row.of(modelData));
    }

    /** Returns a new array of {@code length} entries, each set to {@code value}. */
    private static double[] filled(int length, double value) {
        final double[] result = new double[length];
        Arrays.fill(result, value);
        return result;
    }
}
