package com.alibaba.alink.operator.common.clustering.lda;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.operator.common.clustering.LdaModelData;
import com.alibaba.alink.params.clustering.LdaTrainParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/lda/BuildOnlineLdaModel.class */
/**
 * Final step of the online LDA iteration: assembles an {@link LdaModelData} from the values
 * stored in the {@link ComContext} and emits it as a single {@link Row}.
 *
 * <p>Only task 0 produces output; every other task returns {@code null}.
 */
public class BuildOnlineLdaModel extends CompleteResultFunction {
    private static final long serialVersionUID = 6377566517589659547L;

    /** Number of topics; also the length of the emitted beta array. */
    private final int topicNum;
    /** Constant value written into every entry of the model's beta array. */
    private final double beta;

    /**
     * @param topicNum number of topics of the model (length of the beta array)
     * @param beta     value used to fill every entry of the model's beta array
     */
    public BuildOnlineLdaModel(int topicNum, double beta) {
        this.topicNum = topicNum;
        this.beta = beta;
    }

    /**
     * Builds the online-LDA model from the context variables {@code shape}, {@code alpha},
     * {@code lambda}, {@code nonEmptyWordCount} and {@code logLikelihood}.
     *
     * @param comContext the iteration context holding the accumulated variables
     * @return a mutable one-element list containing a {@link Row} that wraps the model data on
     *     task 0, or {@code null} on every other task
     */
    @Override // com.alibaba.alink.common.comqueue.CompleteResultFunction
    public List<Row> calc(ComContext comContext) {
        // A single task is enough to build and emit the model.
        if (comContext.getTaskId() != 0) {
            return null;
        }

        @SuppressWarnings("unchecked") // the shape variable is stored as a list of Integer pairs
        List<Tuple2<Integer, Integer>> shapes =
            (List<Tuple2<Integer, Integer>>) comContext.getObj(LdaVariable.shape);
        // Second component of the first shape tuple; presumably the vocabulary size — TODO confirm.
        int vocabularySize = shapes.get(0).f1;

        DenseMatrix alphaMatrix = (DenseMatrix) comContext.getObj(LdaVariable.alpha);
        int numAlpha = alphaMatrix.numRows();
        DenseMatrix lambda = (DenseMatrix) comContext.getObj(LdaVariable.lambda);

        LdaModelData modelData = new LdaModelData(
            this.topicNum, vocabularySize, new double[numAlpha], new double[this.topicNum], lambda);
        // The model stores alpha as a flat array: copy column 0 of the alpha matrix into it.
        for (int row = 0; row < numAlpha; row++) {
            modelData.alpha[row] = alphaMatrix.get(row, 0);
        }
        Arrays.fill(modelData.beta, this.beta);
        modelData.optimizer = LdaTrainParams.Method.Online;

        double[] wordCount = (double[]) comContext.getObj(LdaVariable.nonEmptyWordCount);
        long nonEmptyWordCount = Math.round(wordCount[0]);
        modelData.logLikelihood = ((double[]) comContext.getObj(LdaVariable.logLikelihood))[0];
        // NOTE: this is floating-point division, so a zero word count yields Infinity or NaN
        // rather than throwing.
        modelData.logPerplexity = -modelData.logLikelihood / nonEmptyWordCount;

        List<Row> result = new ArrayList<>(1);
        result.add(Row.of(modelData));
        return result;
    }
}
