package com.alibaba.alink.operator.common.statistics.basicstatistic;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.F2jBLAS;
import com.github.fommil.netlib.LAPACK;
import com.google.common.primitives.Doubles;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import org.netlib.util.intW;

/**
 * Multivariate Gaussian distribution defined by a mean vector and a covariance matrix.
 *
 * <p>The covariance may be singular. Its eigenvalues that do not exceed
 * {@code EPSILON * dim * maxEigenvalue} are treated as zero. Densities are then evaluated with the
 * pseudo-inverse and pseudo-determinant of the covariance.
 *
 * <p>{@link #logpdf(Vector)} writes intermediate results into per-thread ({@link ThreadLocal})
 * scratch buffers, so concurrent calls do not share scratch state.
 */
public class MultivariateGaussian implements Serializable {
    private static final LAPACK LAPACK_INST = LAPACK.getInstance();
    private static final BLAS F2J_BLAS_INST = F2jBLAS.getInstance();

    /** Machine epsilon for doubles (2^-52): the gap between 1.0 and the next larger double. */
    private static final double EPSILON = Math.ulp(1.0d);

    private static final long serialVersionUID = -5155224606070313042L;

    /** Mean vector; stored by reference, not copied. */
    private final DenseVector mean;
    /** Covariance matrix; stored by reference, not copied. */
    private final DenseMatrix cov;
    /**
     * Eigenvector matrix of {@code cov} as returned by LAPACK dsyev (column-major). Column i is
     * scaled by 1/sqrt(lambda_i), or zeroed when lambda_i is at or below the tolerance.
     */
    private DenseMatrix rootSigmaInv;
    /** Log normalization term: -0.5 * (dim * log(2*pi) + log pseudo-determinant of cov). */
    private double u;
    /** Per-thread buffer for (x - mean); transient, so it is rebuilt in {@link #readObject}. */
    private transient ThreadLocal<DenseVector> threadLocalDelta;
    /** Per-thread buffer for rootSigmaInv^T * (x - mean); transient, rebuilt in {@link #readObject}. */
    private transient ThreadLocal<DenseVector> threadLocalV;

    /**
     * Creates the distribution and precomputes the factors used by {@link #logpdf(Vector)}.
     *
     * @param denseVector mean vector
     * @param denseMatrix covariance matrix; presumably symmetric and sized to match
     *                    {@code denseVector} (TODO confirm at callers). Only its upper triangle is
     *                    read by dsyev ("U").
     * @throws IllegalArgumentException if LAPACK fails to eigen-decompose the covariance
     */
    public MultivariateGaussian(DenseVector denseVector, DenseMatrix denseMatrix) {
        this.mean = denseVector;
        this.cov = denseMatrix;
        initScratchBuffers();
        calculateCovarianceConstants();
    }

    /**
     * Copy constructor. Shares the mean, covariance and precomputed factors with {@code
     * multivariateGaussian} (no decomposition is redone), but gets its own scratch buffers.
     *
     * @param multivariateGaussian the distribution to copy
     */
    public MultivariateGaussian(MultivariateGaussian multivariateGaussian) {
        this.mean = multivariateGaussian.mean;
        this.cov = multivariateGaussian.cov;
        this.rootSigmaInv = multivariateGaussian.rootSigmaInv;
        this.u = multivariateGaussian.u;
        initScratchBuffers();
    }

    /** (Re)creates the per-thread scratch buffers, each sized to the mean's dimension. */
    private void initScratchBuffers() {
        this.threadLocalDelta = ThreadLocal.withInitial(() -> DenseVector.zeros(this.mean.size()));
        this.threadLocalV = ThreadLocal.withInitial(() -> DenseVector.zeros(this.mean.size()));
    }

    /**
     * Restores the transient scratch buffers after default deserialization. Without this,
     * a deserialized instance would throw NullPointerException in {@link #logpdf(Vector)}.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        initScratchBuffers();
    }

    /**
     * Returns the probability density at {@code vector}.
     *
     * @param vector point to evaluate; must be a DenseVector or SparseVector
     * @return exp(logpdf(vector))
     */
    public double pdf(Vector vector) {
        return Math.exp(logpdf(vector));
    }

    /**
     * Returns the log probability density at {@code vector}. Eigen-directions dropped from a
     * singular covariance do not contribute.
     *
     * @param vector point to evaluate; must be a DenseVector or SparseVector of the mean's dimension
     * @return the log density
     * @throws IllegalArgumentException if {@code vector} is null or of an unsupported Vector type
     */
    public double logpdf(Vector vector) {
        DenseVector delta = this.threadLocalDelta.get();
        DenseVector v = this.threadLocalV.get();
        // delta = x - mean, built in place in the reusable buffer.
        System.arraycopy(this.mean.getData(), 0, delta.getData(), 0, this.mean.size());
        com.alibaba.alink.common.linalg.BLAS.scal(-1.0d, delta);
        if (vector instanceof DenseVector) {
            com.alibaba.alink.common.linalg.BLAS.axpy(1.0d, (DenseVector) vector, delta);
        } else if (vector instanceof SparseVector) {
            com.alibaba.alink.common.linalg.BLAS.axpy(1.0d, (SparseVector) vector, delta);
        } else {
            // Previously fell through and silently evaluated the density at the origin.
            throw new IllegalArgumentException("Unsupported vector type: "
                + (vector == null ? "null" : vector.getClass().getName()));
        }
        // v = rootSigmaInv^T * delta. Beta is 0, so the stale contents of v are discarded.
        com.alibaba.alink.common.linalg.BLAS.gemv(1.0d, this.rootSigmaInv, true, delta, 0.0d, v);
        // v.v is the squared Mahalanobis distance under the pseudo-inverse of cov.
        return this.u - 0.5d * com.alibaba.alink.common.linalg.BLAS.dot(v, v);
    }

    /**
     * Eigen-decomposes the covariance and fills {@link #rootSigmaInv} and {@link #u}.
     *
     * @throws IllegalArgumentException if LAPACK dsyev reports a non-zero info code
     */
    private void calculateCovarianceConstants() {
        final int n = this.mean.size();
        final int lwork = 3 * n - 1;
        final double[] eigenvectors = new double[n * n];
        final double[] work = new double[lwork];
        final double[] eigenvalues = new double[n];
        final intW info = new intW(0);

        // dsyev overwrites its matrix argument with eigenvectors, so decompose a copy of cov.
        System.arraycopy(this.cov.getData(), 0, eigenvectors, 0, n * n);
        LAPACK_INST.dsyev("V", "U", n, eigenvectors, n, eigenvalues, work, lwork, info);
        if (info.val != 0) {
            throw new IllegalArgumentException(
                "Eigen-decomposition of the covariance matrix failed, LAPACK dsyev info = " + info.val);
        }

        // Eigenvalues at or below this relative tolerance are treated as zero (pseudo-inverse).
        final double tol = EPSILON * n * Doubles.max(eigenvalues);
        double logPseudoDet = 0.0d;
        for (int i = 0; i < n; i++) {
            final boolean keep = eigenvalues[i] > tol;
            if (keep) {
                logPseudoDet += Math.log(eigenvalues[i]);
            }
            // Scale eigenvector column i by 1/sqrt(lambda_i), or zero it for a dropped direction.
            F2J_BLAS_INST.dscal(n, keep ? Math.sqrt(1.0d / eigenvalues[i]) : 0.0d, eigenvectors, i * n, 1);
        }
        this.rootSigmaInv = new DenseMatrix(n, n, eigenvectors);
        this.u = -0.5d * (n * Math.log(2.0d * Math.PI) + logPseudoDet);
    }
}
