package com.alibaba.alink.common.linalg;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.github.fommil.netlib.F2jBLAS;

/* loaded from: input_file:com/alibaba/alink/common/linalg/BLAS.class */
/**
 * Static helpers implementing a subset of BLAS level-1/2/3 operations over Alink's linear-algebra
 * types ({@link DenseVector}, {@link SparseVector}, {@link DenseMatrix}) and raw {@code double[]}
 * arrays.
 *
 * <p>Level-1 style routines and the sparse {@code gemv} delegate to the pure-Java {@link F2jBLAS};
 * {@code gemm} and the dense {@code gemv} go through {@code com.github.fommil.netlib.BLAS.getInstance()}.
 * Matrices are handed to BLAS with {@code numRows()} as the leading dimension, and the sparse
 * {@code gemv} addresses column {@code j} at offset {@code j * numRows}, i.e. matrix data is treated
 * as column-major.
 */
public class BLAS {
    /** Pure-Java netlib implementation used for the level-1 style routines. */
    private static final com.github.fommil.netlib.BLAS F2J_BLAS = new F2jBLAS();

    /**
     * Implementation selected once by netlib-java when this class is initialized (may be a native
     * library, depending on netlib-java configuration); used for {@code gemm} and dense {@code gemv}.
     */
    private static final com.github.fommil.netlib.BLAS NATIVE_BLAS = com.github.fommil.netlib.BLAS.getInstance();

    /**
     * Returns the sum of absolute values of {@code i} consecutive elements of {@code dArr},
     * starting at offset {@code i2}.
     *
     * @param i    number of elements to sum
     * @param dArr source array
     * @param i2   offset of the first element
     */
    public static double asum(int i, double[] dArr, int i2) {
        return F2J_BLAS.dasum(i, dArr, i2, 1);
    }

    /** Returns the sum of absolute values of all entries of {@code denseVector}. */
    public static double asum(DenseVector denseVector) {
        return asum(denseVector.data.length, denseVector.data, 0);
    }

    /** Returns the sum of absolute values of the stored (non-implicit) values of {@code sparseVector}. */
    public static double asum(SparseVector sparseVector) {
        return asum(sparseVector.values.length, sparseVector.values, 0);
    }

    /**
     * Computes {@code dArr2 += d * dArr} in place.
     *
     * <p>Fails via {@code AkPreconditions.checkArgument} if the arrays differ in length.
     */
    public static void axpy(double d, double[] dArr, double[] dArr2) {
        AkPreconditions.checkArgument(dArr.length == dArr2.length, "Array dimension mismatched.");
        F2J_BLAS.daxpy(dArr.length, d, dArr, 1, dArr2, 1);
    }

    /**
     * Computes {@code denseVector2 += d * denseVector} in place.
     *
     * <p>Fails via {@code AkPreconditions.checkArgument} if the vectors differ in length.
     */
    public static void axpy(double d, DenseVector denseVector, DenseVector denseVector2) {
        AkPreconditions.checkArgument(denseVector.data.length == denseVector2.data.length, "Vector dimension mismatched.");
        F2J_BLAS.daxpy(denseVector.data.length, d, denseVector.data, 1, denseVector2.data, 1);
    }

    /**
     * Computes {@code denseVector += d * sparseVector} in place, touching only the stored entries
     * of the sparse vector.
     *
     * <p>If the sparse vector's size is known (not {@code -1}) it must equal the dense vector's
     * size. If the size is {@code -1} (treated as undetermined), stored entries whose index is
     * {@code >= denseVector.size()} are ignored; this trimming assumes the indices are sorted
     * ascending — NOTE(review): confirm SparseVector guarantees that.
     */
    public static void axpy(double d, SparseVector sparseVector, DenseVector denseVector) {
        int last;
        if (sparseVector.size() != -1) {
            AkPreconditions.checkArgument(sparseVector.size() == denseVector.size(), "Vector dimension mismatched.");
            last = sparseVector.indices.length - 1;
        } else {
            // Bounded scan: safe for vectors with no stored entries or with all indices out of range.
            last = lastIndexBelow(sparseVector.indices, denseVector.size());
        }
        for (int k = 0; k <= last; k++) {
            denseVector.data[sparseVector.indices[k]] += d * sparseVector.values[k];
        }
    }

    /** Computes {@code denseVector += d * vector}, dispatching on the runtime type of {@code vector}. */
    public static void axpy(double d, Vector vector, DenseVector denseVector) {
        if (vector instanceof SparseVector) {
            axpy(d, (SparseVector) vector, denseVector);
        } else {
            axpy(d, (DenseVector) vector, denseVector);
        }
    }

    /**
     * Computes {@code denseMatrix2 += d * denseMatrix} element-wise, in place.
     *
     * <p>Fails via {@code AkPreconditions.checkArgument} if the matrices differ in shape.
     */
    public static void axpy(double d, DenseMatrix denseMatrix, DenseMatrix denseMatrix2) {
        AkPreconditions.checkArgument(denseMatrix.m == denseMatrix2.m && denseMatrix.n == denseMatrix2.n, "Matrix dimension mismatched.");
        F2J_BLAS.daxpy(denseMatrix.data.length, d, denseMatrix.data, 1, denseMatrix2.data, 1);
    }

    /**
     * Computes {@code dArr2[i3 + j] += d * dArr[i2 + j]} for {@code j} in {@code [0, i)}.
     *
     * @param i    number of elements to update
     * @param d    scaling factor
     * @param dArr source array
     * @param i2   offset into {@code dArr}
     * @param dArr2 destination array, updated in place
     * @param i3   offset into {@code dArr2}
     */
    public static void axpy(int i, double d, double[] dArr, int i2, double[] dArr2, int i3) {
        F2J_BLAS.daxpy(i, d, dArr, i2, 1, dArr2, i3, 1);
    }

    /**
     * Returns the dot product of two equal-length arrays.
     *
     * <p>Fails via {@code AkPreconditions.checkArgument} if the arrays differ in length.
     */
    public static double dot(double[] dArr, double[] dArr2) {
        AkPreconditions.checkArgument(dArr.length == dArr2.length, "Array dimension mismatched.");
        double sum = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            sum += dArr[k] * dArr2[k];
        }
        return sum;
    }

    /** Returns the dot product of two dense vectors (see {@link #dot(double[], double[])}). */
    public static double dot(DenseVector denseVector, DenseVector denseVector2) {
        return dot(denseVector.getData(), denseVector2.getData());
    }

    /** Multiplies every element of {@code dArr} by {@code d}, in place. */
    public static void scal(double d, double[] dArr) {
        F2J_BLAS.dscal(dArr.length, d, dArr, 1);
    }

    /**
     * Multiplies {@code i2} consecutive elements of {@code dArr}, starting at offset {@code i}, by
     * {@code d}, in place.
     */
    public static void scal(double d, double[] dArr, int i, int i2) {
        F2J_BLAS.dscal(i2, d, dArr, i, 1);
    }

    /** Multiplies every entry of {@code denseVector} by {@code d}, in place. */
    public static void scal(double d, DenseVector denseVector) {
        F2J_BLAS.dscal(denseVector.data.length, d, denseVector.data, 1);
    }

    /** Multiplies the stored values of {@code sparseVector} by {@code d}, in place. */
    public static void scal(double d, SparseVector sparseVector) {
        F2J_BLAS.dscal(sparseVector.values.length, d, sparseVector.values, 1);
    }

    /** Multiplies every entry of {@code denseMatrix} by {@code d}, in place. */
    public static void scal(double d, DenseMatrix denseMatrix) {
        F2J_BLAS.dscal(denseMatrix.data.length, d, denseMatrix.data, 1);
    }

    /**
     * Computes {@code C = d * op(A) * op(B) + d2 * C} in place, where {@code A = denseMatrix},
     * {@code B = denseMatrix2}, {@code C = denseMatrix3}, and {@code op(X)} is {@code X^T} when the
     * corresponding flag ({@code z} for A, {@code z2} for B) is true, else {@code X}.
     *
     * <p>Fails via {@code AkPreconditions.checkArgument} if the shapes of {@code op(A)},
     * {@code op(B)} and {@code C} are incompatible.
     */
    public static void gemm(double d, DenseMatrix denseMatrix, boolean z, DenseMatrix denseMatrix2, boolean z2, double d2, DenseMatrix denseMatrix3) {
        AkPreconditions.checkArgument((z ? denseMatrix.m : denseMatrix.n) == (z2 ? denseMatrix2.n : denseMatrix2.m) && (z ? denseMatrix.n : denseMatrix.m) == denseMatrix3.m && (z2 ? denseMatrix2.m : denseMatrix2.n) == denseMatrix3.n, "matrix size mismatched.");
        int m = denseMatrix3.numRows();
        int n = denseMatrix3.numCols();
        // Inner dimension shared by op(A) (columns) and op(B) (rows).
        int k = z ? denseMatrix.numRows() : denseMatrix.numCols();
        // Leading dimensions are the stored row counts (column-major layout).
        NATIVE_BLAS.dgemm(z ? "T" : "N", z2 ? "T" : "N", m, n, k,
            d, denseMatrix.getData(), denseMatrix.numRows(),
            denseMatrix2.getData(), denseMatrix2.numRows(),
            d2, denseMatrix3.getData(), denseMatrix3.numRows());
    }

    /**
     * Checks that {@code op(A) * x} fits into {@code y} for dense {@code x}, where
     * {@code op(A) = A^T} when {@code z} is true.
     */
    private static void gemvDimensionCheck(DenseMatrix denseMatrix, boolean z, DenseVector denseVector, DenseVector denseVector2) {
        if (z) {
            AkPreconditions.checkArgument(denseMatrix.numCols() == denseVector2.size() && denseMatrix.numRows() == denseVector.size(), "Matrix and vector size mismatched.");
        } else {
            AkPreconditions.checkArgument(denseMatrix.numRows() == denseVector2.size() && denseMatrix.numCols() == denseVector.size(), "Matrix and vector size mismatched.");
        }
    }

    /**
     * Checks that {@code op(A) * x} fits into {@code y} for sparse {@code x}. When the sparse
     * vector's size is {@code -1} only the output dimension is checked.
     */
    private static void gemvDimensionCheck(DenseMatrix denseMatrix, boolean z, SparseVector sparseVector, DenseVector denseVector) {
        if (sparseVector.size() != -1) {
            if (z) {
                AkPreconditions.checkArgument(denseMatrix.numCols() == denseVector.size() && denseMatrix.numRows() == sparseVector.size(), "Matrix and vector size mismatched.");
            } else {
                AkPreconditions.checkArgument(denseMatrix.numRows() == denseVector.size() && denseMatrix.numCols() == sparseVector.size(), "Matrix and vector size mismatched.");
            }
            return;
        }
        if (z) {
            AkPreconditions.checkArgument(denseMatrix.numCols() == denseVector.size(), "Matrix and vector size mismatched.");
        } else {
            AkPreconditions.checkArgument(denseMatrix.numRows() == denseVector.size(), "Matrix and vector size mismatched.");
        }
    }

    /**
     * Computes {@code y = d * op(A) * x + d2 * y} in place, dispatching on the runtime type of
     * {@code vector}; {@code op(A) = A^T} when {@code z} is true.
     */
    public static void gemv(double d, DenseMatrix denseMatrix, boolean z, Vector vector, double d2, DenseVector denseVector) {
        if (vector instanceof SparseVector) {
            gemv(d, denseMatrix, z, (SparseVector) vector, d2, denseVector);
        } else {
            gemv(d, denseMatrix, z, (DenseVector) vector, d2, denseVector);
        }
    }

    /**
     * Computes {@code denseVector2 = d * op(A) * denseVector + d2 * denseVector2} in place via
     * BLAS {@code dgemv}; {@code op(A) = A^T} when {@code z} is true.
     */
    public static void gemv(double d, DenseMatrix denseMatrix, boolean z, DenseVector denseVector, double d2, DenseVector denseVector2) {
        gemvDimensionCheck(denseMatrix, z, denseVector, denseVector2);
        NATIVE_BLAS.dgemv(z ? "T" : "N", denseMatrix.numRows(), denseMatrix.numCols(), d, denseMatrix.getData(), denseMatrix.numRows(), denseVector.getData(), 1, d2, denseVector2.getData(), 1);
    }

    /**
     * Computes {@code y = d * op(A) * x + d2 * y} in place for a sparse {@code x}
     * ({@code y = denseVector}); {@code op(A) = A^T} when {@code z} is true.
     *
     * <p>Stored entries of {@code x} whose index is outside the matrix's matching dimension are
     * ignored; this assumes the sparse indices are sorted ascending — NOTE(review): confirm.
     * A sparse vector with no stored entries just scales {@code y} by {@code d2}.
     */
    public static void gemv(double d, DenseMatrix denseMatrix, boolean z, SparseVector sparseVector, double d2, DenseVector denseVector) {
        gemvDimensionCheck(denseMatrix, z, sparseVector, denseVector);
        int numRows = denseMatrix.numRows();
        int numCols = denseMatrix.numCols();
        if (!z) {
            // y = d2 * y + sum_j (d * x_j) * A[:, j]; column j starts at offset j * numRows.
            int last = lastIndexBelow(sparseVector.indices, numCols);
            scal(d2, denseVector);
            for (int k = 0; k <= last; k++) {
                F2J_BLAS.daxpy(numRows, d * sparseVector.values[k], denseMatrix.data, sparseVector.indices[k] * numRows, 1, denseVector.data, 0, 1);
            }
        } else {
            // y_col = d2 * y_col + d * <A[:, col], x> for each column of A.
            int last = lastIndexBelow(sparseVector.indices, numRows);
            int colStart = 0;
            for (int col = 0; col < numCols; col++) {
                double sum = 0.0d;
                for (int k = 0; k <= last; k++) {
                    sum += sparseVector.values[k] * denseMatrix.data[colStart + sparseVector.indices[k]];
                }
                denseVector.data[col] = (d2 * denseVector.data[col]) + (d * sum);
                colStart += numRows;
            }
        }
    }

    /**
     * Returns the position of the last element of {@code indices} whose value is below
     * {@code bound}, scanning backwards from the end, or {@code -1} if there is none (including
     * when {@code indices} is empty). Callers iterate positions {@code 0..result}, which is only
     * the full in-range set if {@code indices} is sorted ascending.
     */
    private static int lastIndexBelow(int[] indices, int bound) {
        int pos = indices.length - 1;
        while (pos >= 0 && indices[pos] >= bound) {
            pos--;
        }
        return pos;
    }
}
