package com.alibaba.alink.operator.common.clustering.lda;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.utils.JsonConverter;
import java.lang.reflect.Type;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.commons.math3.special.Gamma;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;

/* loaded from: input_file:com/alibaba/alink/operator/common/clustering/lda/LdaUtil.class */
/**
 * Numerical helpers shared by the LDA operators.
 *
 * <p>Provides element-wise special functions on {@link DenseMatrix}, Dirichlet expectations,
 * iterative per-document topic inference, and parsing of the JSON-serialized word dictionary.
 */
public class LdaUtil {

    /** Mean absolute change of gamma at or below which per-document topic inference stops. */
    private static final double GAMMA_CONVERGENCE_THRESHOLD = 1.0E-3d;

    /** Added to every phi normalizer so the later element-wise division never divides by zero. */
    private static final double PHI_NORM_EPSILON = 1.0E-100d;

    /**
     * JSON type of one serialized word-dictionary entry.
     *
     * <p>Presumably (word, weight, wordId), judging by how the {@code setWordIdWeight*} methods
     * consume fields f0/f1/f2. Confirm against the code that writes the dictionary.
     */
    private static final Type WORD_ENTRY_TYPE =
        new TypeReference<Tuple3<String, Double, Integer>>() {
        }.getType();

    /** Returns the digamma function (derivative of log-gamma) evaluated at {@code d}. */
    public static double digamma(double d) {
        return Gamma.digamma(d);
    }

    /** Returns a new matrix holding {@link #digamma(double)} of every element of the input. */
    private static DenseMatrix digamma(DenseMatrix denseMatrix) {
        return mapElements(denseMatrix, v -> digamma(v));
    }

    /** Returns the trigamma function (derivative of digamma) evaluated at {@code d}. */
    public static double trigamma(double d) {
        return Gamma.trigamma(d);
    }

    /** Returns a new matrix holding {@link #trigamma(double)} of every element of the input. */
    public static DenseMatrix trigamma(DenseMatrix denseMatrix) {
        return mapElements(denseMatrix, v -> trigamma(v));
    }

    /** Returns the natural logarithm of the gamma function evaluated at {@code d}. */
    public static double logGamma(double d) {
        return Gamma.logGamma(d);
    }

    /** Returns a new matrix holding {@link #logGamma(double)} of every element of the input. */
    public static DenseMatrix logGamma(DenseMatrix denseMatrix) {
        return mapElements(denseMatrix, v -> logGamma(v));
    }

    /**
     * Computes the Dirichlet expectation E[log x] for each row of a matrix.
     *
     * <p>Each row is treated as a separate parameter vector. Element (r, c) of the result is
     * {@code digamma(m[r][c]) - digamma(sum of row r)}.
     *
     * @param denseMatrix parameter matrix; not modified
     * @return a new matrix of the same shape
     */
    public static DenseMatrix dirichletExpectation(DenseMatrix denseMatrix) {
        DenseMatrix digammaOfRowSums = digamma(rowSums(denseMatrix));
        DenseMatrix result = digamma(denseMatrix);
        for (int col = 0; col < denseMatrix.numCols(); col++) {
            for (int row = 0; row < denseMatrix.numRows(); row++) {
                result.set(row, col, result.get(row, col) - digammaOfRowSums.get(0, row));
            }
        }
        return result;
    }

    /**
     * Computes the Dirichlet expectation E[log x] treating the whole matrix as one parameter vector.
     *
     * <p>Each element becomes {@code digamma(x) - digamma(sum of all elements)}. Callers pass
     * column vectors such as gamma.
     *
     * @param denseMatrix parameters; not modified
     * @return a new matrix of the same shape
     */
    public static DenseMatrix dirichletExpectationVec(DenseMatrix denseMatrix) {
        DenseMatrix result = digamma(denseMatrix);
        result.plusEquals(-digamma(denseMatrix.sum()));
        return result;
    }

    /** Replaces every element of {@code denseMatrix} with its exponential, in place. */
    public static void exp(DenseMatrix denseMatrix) {
        double[] data = denseMatrix.getData();
        for (int k = 0; k < data.length; k++) {
            data[k] = Math.exp(data[k]);
        }
    }

    /** Returns {@code exp(dirichletExpectation(denseMatrix))} as a new matrix; the input is untouched. */
    public static DenseMatrix expDirichletExpectation(DenseMatrix denseMatrix) {
        DenseMatrix result = dirichletExpectation(denseMatrix);
        exp(result);
        return result;
    }

    /**
     * Infers the topic weights (gamma) of one document, starting from a random initial gamma.
     *
     * @param sparseVector word counts of the document
     * @param expELogBeta  per-word topic weights: one row per vocabulary index, one column per topic
     * @param alpha        prior added to gamma on every iteration
     * @param topicNum     number of topics
     * @param gammaShape   shape of the Gamma distribution used to draw the initial gamma
     * @param random       source of randomness for the initial gamma
     * @return gamma as an array of length {@code topicNum}; all zeros for an empty document
     */
    public static double[] getTopicDistributionMethod(SparseVector sparseVector, DenseMatrix expELogBeta,
                                                      DenseMatrix alpha, int topicNum, int gammaShape,
                                                      RandomDataGenerator random) {
        // Draw the initial gamma even for empty documents so the RNG stream stays aligned with
        // earlier behaviour for all subsequent documents.
        DenseMatrix initialGamma = geneGamma(topicNum, gammaShape, random);
        if (sparseVector.numberOfValues() == 0) {
            // The matrix overload returns a 1 x topicNum zero matrix here, whose column 0 has
            // length 1. Return one weight per topic instead.
            return new double[topicNum];
        }
        return getTopicDistributionMethod(sparseVector, expELogBeta, alpha, initialGamma, topicNum)
            .f0.getColumn(0);
    }

    /**
     * Draws a random initial gamma.
     *
     * @param topicNum   number of entries to draw
     * @param gammaShape shape of the Gamma distribution; the scale is {@code 1 / gammaShape}
     * @param random     random generator used for the draws
     * @return a {@code topicNum x 1} column matrix
     */
    public static DenseMatrix geneGamma(int topicNum, int gammaShape, RandomDataGenerator random) {
        double[] draws = new double[topicNum];
        for (int t = 0; t < topicNum; t++) {
            draws[t] = random.nextGamma(gammaShape, 1.0d / gammaShape);
        }
        return vectorToMatrix(draws);
    }

    /**
     * Iteratively infers the topic weights of one document (a sparse word-count vector).
     *
     * <p>Starting from {@code gammad}, repeats
     * {@code gamma = expELogTheta .* (expELogBeta_d^T * (counts ./ phiNorm)) + alpha}, where
     * {@code expELogBeta_d} holds the rows of {@code expELogBeta} selected by the document's
     * indices. Iteration stops once the summed absolute change of gamma divided by
     * {@code topicNum} is at most 1e-3. There is no iteration cap.
     *
     * @param sparseVector word counts of the document; its indices select rows of {@code expELogBeta}
     * @param expELogBeta  per-word topic weights, one row per vocabulary index
     *                     (presumably exp(E[log beta]); confirm with callers)
     * @param alpha        prior added to gamma on every iteration; assumed same shape as gamma
     * @param gammad       initial gamma (e.g. from {@link #geneGamma}); not modified
     * @param topicNum     number of topics, used as the divisor of the convergence measure
     * @return (final gamma, {@code expELogTheta * (counts ./ phiNorm)^T}). For an empty document
     *     both entries are {@code 1 x topicNum} zero matrices, which is the transposed shape of the
     *     column-vector gamma produced for non-empty documents.
     */
    public static Tuple2<DenseMatrix, DenseMatrix> getTopicDistributionMethod(SparseVector sparseVector,
                                                                              DenseMatrix expELogBeta,
                                                                              DenseMatrix alpha,
                                                                              DenseMatrix gammad,
                                                                              int topicNum) {
        if (sparseVector.numberOfValues() == 0) {
            return new Tuple2<>(DenseMatrix.zeros(1, topicNum), DenseMatrix.zeros(1, topicNum));
        }
        DenseMatrix counts = vectorToMatrix(sparseVector.getValues());
        DenseMatrix docExpELogBeta = expELogBeta.selectRows(sparseVector.getIndices());
        DenseMatrix expELogTheta = expDirichletExpectationVec(gammad);
        DenseMatrix phiNorm = phiNorm(docExpELogBeta, expELogTheta);

        double meanGammaChange = 1.0d;
        while (meanGammaChange > GAMMA_CONVERGENCE_THRESHOLD) {
            // gammad is rebound to a freshly allocated matrix below and never mutated in place,
            // so keeping a reference to the previous value is enough (no copy needed).
            DenseMatrix lastGamma = gammad;
            gammad = elementWiseProduct(expELogTheta,
                docExpELogBeta.transpose().multiplies(elementWiseDivide(counts, phiNorm)));
            gammad.plusEquals(alpha);
            expELogTheta = expDirichletExpectationVec(gammad);
            phiNorm = phiNorm(docExpELogBeta, expELogTheta);
            meanGammaChange = diffDenseMatrix(gammad, lastGamma, topicNum);
        }
        return new Tuple2<>(gammad, expELogTheta.multiplies(elementWiseDivide(counts, phiNorm).transpose()));
    }

    /** Returns {@code exp(dirichletExpectationVec(gamma))} as a new matrix. */
    private static DenseMatrix expDirichletExpectationVec(DenseMatrix gamma) {
        DenseMatrix result = dirichletExpectationVec(gamma);
        exp(result);
        return result;
    }

    /** Returns {@code docExpELogBeta * expELogTheta + 1e-100}, the strictly positive phi normalizer. */
    private static DenseMatrix phiNorm(DenseMatrix docExpELogBeta, DenseMatrix expELogTheta) {
        DenseMatrix norm = docExpELogBeta.multiplies(expELogTheta);
        norm.plusEquals(PHI_NORM_EPSILON);
        return norm;
    }

    /**
     * Copies {@code dArr} into a new {@code dArr.length x 1} column matrix.
     *
     * @param dArr values to copy; not retained
     * @return the column matrix
     */
    public static DenseMatrix vectorToMatrix(double[] dArr) {
        DenseMatrix column = new DenseMatrix(dArr.length, 1);
        for (int r = 0; r < dArr.length; r++) {
            column.set(r, 0, dArr[r]);
        }
        return column;
    }

    /**
     * Returns the sum of absolute element-wise differences between two equally shaped matrices,
     * divided by {@code divisor}.
     */
    private static double diffDenseMatrix(DenseMatrix current, DenseMatrix previous, int divisor) {
        double total = 0.0d;
        for (int col = 0; col < current.numCols(); col++) {
            for (int row = 0; row < current.numRows(); row++) {
                total += Math.abs(current.get(row, col) - previous.get(row, col));
            }
        }
        return total / divisor;
    }

    /** Returns the element-wise (Hadamard) product of two matrices as a new matrix shaped like the first. */
    public static DenseMatrix elementWiseProduct(DenseMatrix denseMatrix, DenseMatrix denseMatrix2) {
        return zipElements(denseMatrix, denseMatrix2, (a, b) -> a * b);
    }

    /** Returns the element-wise quotient {@code denseMatrix ./ denseMatrix2} as a new matrix. */
    public static DenseMatrix elementWiseDivide(DenseMatrix denseMatrix, DenseMatrix denseMatrix2) {
        return zipElements(denseMatrix, denseMatrix2, (a, b) -> a / b);
    }

    /** Returns a new matrix holding {@code func} applied to every element of {@code src}. */
    private static DenseMatrix mapElements(DenseMatrix src, Function<Double, Double> func) {
        DenseMatrix result = new DenseMatrix(src.numRows(), src.numCols());
        MatVecOp.apply(src, result, func);
        return result;
    }

    /**
     * Returns a new matrix, shaped like {@code left}, holding {@code func} applied pairwise to the
     * elements of {@code left} and {@code right}.
     */
    private static DenseMatrix zipElements(DenseMatrix left, DenseMatrix right,
                                           BiFunction<Double, Double, Double> func) {
        DenseMatrix result = new DenseMatrix(left.numRows(), left.numCols());
        MatVecOp.apply(left, right, result, func);
        return result;
    }

    /** Returns a {@code 1 x numRows} matrix whose entry i is the sum of row i of the input. */
    private static DenseMatrix rowSums(DenseMatrix denseMatrix) {
        int numRows = denseMatrix.numRows();
        int numCols = denseMatrix.numCols();
        DenseMatrix sums = new DenseMatrix(1, numRows);
        for (int row = 0; row < numRows; row++) {
            double total = 0.0d;
            for (int col = 0; col < numCols; col++) {
                total += denseMatrix.get(row, col);
            }
            sums.set(0, row, total);
        }
        return sums;
    }

    /**
     * Returns a {@code 1 x numCols} matrix whose entry j is the sum of column j of the input.
     *
     * <p>Despite the name, this sums down each column. The name is kept for compatibility.
     */
    public static DenseMatrix sumByRow(DenseMatrix denseMatrix) {
        int numRows = denseMatrix.numRows();
        int numCols = denseMatrix.numCols();
        DenseMatrix sums = new DenseMatrix(1, numCols);
        for (int col = 0; col < numCols; col++) {
            double total = 0.0d;
            for (int row = 0; row < numRows; row++) {
                total += denseMatrix.get(row, col);
            }
            sums.set(0, col, total);
        }
        return sums;
    }

    /**
     * Returns {@code log(sum(exp(m)))} over all elements, shifted by the maximum for stability.
     *
     * <p>If the maximum is not finite (all elements {@code -inf}, some element {@code +inf}, or a
     * NaN first element), no shift is applied. The result is then the mathematically expected
     * {@code -inf}/{@code +inf} (or NaN when NaN is present) instead of the NaN that
     * {@code inf - inf} would produce.
     *
     * @param denseMatrix non-empty matrix; not modified
     * @return the log-sum-exp of all elements
     */
    public static double logSumExp(DenseMatrix denseMatrix) {
        double max = denseMatrix.get(0, 0);
        for (int row = 0; row < denseMatrix.numRows(); row++) {
            for (int col = 0; col < denseMatrix.numCols(); col++) {
                if (denseMatrix.get(row, col) > max) {
                    max = denseMatrix.get(row, col);
                }
            }
        }
        double shift = Double.isFinite(max) ? max : 0.0d;
        double sum = 0.0d;
        for (int row = 0; row < denseMatrix.numRows(); row++) {
            for (int col = 0; col < denseMatrix.numCols(); col++) {
                sum += Math.exp(denseMatrix.get(row, col) - shift);
            }
        }
        return shift + Math.log(sum);
    }

    /**
     * Builds a map from word id (field f2) to word (field f0).
     *
     * <p>Each input string is a JSON-encoded {@code Tuple3<String, Double, Integer>}. Later
     * entries overwrite earlier ones with the same id.
     *
     * @param list JSON-encoded dictionary entries
     * @return mutable map from word id to word
     */
    public static HashMap<Integer, String> setWordIdWeightTrain(List<String> list) {
        HashMap<Integer, String> idToWord = new HashMap<>(list.size());
        for (String json : list) {
            Tuple3<String, Double, Integer> entry = parseWordEntry(json);
            idToWord.put(entry.f2, entry.f0);
        }
        return idToWord;
    }

    /**
     * Builds a map from word (field f0) to (word id f2, weight f1).
     *
     * <p>Each input string is a JSON-encoded {@code Tuple3<String, Double, Integer>}. Later
     * entries overwrite earlier ones with the same word.
     *
     * @param list JSON-encoded dictionary entries
     * @return mutable map from word to (id, weight)
     */
    public static HashMap<String, Tuple2<Integer, Double>> setWordIdWeightPredict(List<String> list) {
        HashMap<String, Tuple2<Integer, Double>> wordToIdWeight = new HashMap<>(list.size());
        for (String json : list) {
            Tuple3<String, Double, Integer> entry = parseWordEntry(json);
            wordToIdWeight.put(entry.f0, Tuple2.of(entry.f2, entry.f1));
        }
        return wordToIdWeight;
    }

    /** Deserializes one JSON dictionary entry into a typed {@code Tuple3<String, Double, Integer>}. */
    @SuppressWarnings("unchecked") // WORD_ENTRY_TYPE describes exactly this generic type.
    private static Tuple3<String, Double, Integer> parseWordEntry(String json) {
        return (Tuple3<String, Double, Integer>) JsonConverter.fromJson(json, WORD_ENTRY_TYPE);
    }
}
