package com.alibaba.alink.operator.common.regression;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.jama.JMatrixFunc;
import com.alibaba.alink.common.probabilistic.CDF;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:com/alibaba/alink/operator/common/regression/LinearReg.class */
/**
 * Fits a linear regression model from the column statistics held in a
 * {@link SummaryResultTable} (counts, sums, means, variance and the covariance matrix)
 * instead of iterating over raw rows.
 *
 * <p>The coefficients are obtained by solving {@code cov(X, X) * b = cov(X, Y)} in the
 * least-squares sense; the intercept is {@code mean(Y) - sum(mean(X_k) * b_k)}. The fitted
 * model is then populated with its analysis-of-variance fields (SST/SSR/SSE, R2, F, AIC,
 * per-coefficient F/T statistics and p-values).
 */
public class LinearReg {
    /**
     * Fits a regression of the column named {@code str} on the columns named in {@code strArr}.
     *
     * @param summaryResultTable summary statistics of the training table; must not be {@code null}
     * @param str                name of the dependent (Y) column
     * @param strArr             names of the independent (X) columns; must contain at least one name
     * @return the fitted model
     * @throws Exception if the table is {@code null}, no X column is given, a referenced column is
     *                   not of type {@code Double}, {@code Long} or {@code Integer}, or the
     *                   statistics fail the validation done by the index-based overload
     *                   (column lookup itself is delegated to {@code TableUtil})
     */
    public static LinearRegressionModel train(SummaryResultTable summaryResultTable, String str, String[] strArr) throws Exception {
        if (summaryResultTable == null) {
            throw new Exception("srt must not null!");
        }
        String[] colNames = summaryResultTable.colNames;
        Class<?>[] colTypes = new Class<?>[colNames.length];
        for (int c = 0; c < colNames.length; c++) {
            colTypes[c] = summaryResultTable.col(c).dataType;
        }

        int yIndex = TableUtil.findColIndexWithAssertAndHint(colNames, str);
        requireNumericType(colTypes[yIndex]);

        // A null array is rejected the same way as an empty one instead of failing with an NPE.
        if (strArr == null || strArr.length == 0) {
            throw new Exception("nameX must input!");
        }

        // Resolve each X column once: the index used for the type check is the one passed on.
        int[] xIndices = new int[strArr.length];
        for (int k = 0; k < strArr.length; k++) {
            xIndices[k] = TableUtil.findColIndexWithAssertAndHint(colNames, strArr[k]);
            requireNumericType(colTypes[xIndices[k]]);
        }
        return train(summaryResultTable, yIndex, xIndices, str, strArr);
    }

    /**
     * Placeholder: no loading logic is present in this class.
     *
     * @param str unused
     * @return always {@code null}
     * @throws Exception never thrown by the current body
     */
    public static LinearRegressionModel loadModel(String str) throws Exception {
        return null;
    }

    /**
     * Placeholder: the body is empty, so calling this method has no effect.
     *
     * @throws Exception never thrown by the current body
     */
    public static void predict(LinearRegressionModel linearRegressionModel, String str, String[] strArr, String[] strArr2, String str2, String str3) throws Exception {
    }

    /**
     * Placeholder: the body is empty, so calling this method has no effect.
     *
     * @throws Exception never thrown by the current body
     */
    public static void write(LinearRegressionModel linearRegressionModel, String str) throws Exception {
    }

    /**
     * Fits a regression using column indices; the column names handed to the model are read
     * from {@code summaryResultTable.colNames}.
     *
     * @param summaryResultTable summary statistics of the training table
     * @param i                  index of the dependent (Y) column
     * @param iArr               indices of the independent (X) columns
     * @return the fitted model
     * @throws Exception if the statistics fail validation (missing/NaN values, empty table,
     *                   fewer records than features)
     */
    static LinearRegressionModel train(SummaryResultTable summaryResultTable, int i, int[] iArr) throws Exception {
        String[] xNames = new String[iArr.length];
        for (int k = 0; k < iArr.length; k++) {
            xNames[k] = summaryResultTable.colNames[iArr[k]];
        }
        return train(summaryResultTable, i, iArr, summaryResultTable.colNames[i], xNames);
    }

    /**
     * Core fit shared by both entry points: validates the statistics, solves for the
     * coefficients and fills in the model's diagnostic fields.
     *
     * @param srt      summary statistics of the training table
     * @param yIndex   index of the dependent column
     * @param xIndices indices of the independent columns
     * @param nameY    name of the dependent column (used in messages and stored in the model)
     * @param nameX    names of the independent columns, parallel to {@code xIndices}
     * @return the fitted model
     * @throws Exception if any used column has missing or NaN values, the table is empty,
     *                   there are fewer records than features, or Y has no valid value
     */
    private static LinearRegressionModel train(SummaryResultTable srt, int yIndex, int[] xIndices, String nameY, String[] nameX) throws Exception {
        requireNoMissingOrNan(srt, yIndex, nameY);
        for (int k = 0; k < xIndices.length; k++) {
            requireNoMissingOrNan(srt, xIndices[k], nameX[k]);
        }
        if (srt.col(0).countTotal == 0) {
            throw new Exception("table is empty!");
        }
        if (srt.col(0).countTotal < nameX.length) {
            throw new Exception("record size Less than features size!");
        }

        int nx = xIndices.length;
        long n = srt.col(yIndex).count;
        if (n == 0) {
            throw new Exception("Y valid value num is zero!");
        }

        // Keep only the names of X columns that have at least one counted value.
        // NOTE(review): toArray(nameX) writes the result back into the caller's array (and puts
        // a null after the last kept name if any column was dropped), while beta/cov below are
        // still indexed by the full xIndices. Confirm whether a column can reach this point with
        // count == 0 after the checks above.
        List<String> keptNames = new ArrayList<>();
        for (int k = 0; k < xIndices.length; k++) {
            if (srt.col(xIndices[k]).count != 0) {
                keptNames.add(nameX[k]);
            }
        }
        keptNames.toArray(nameX);

        LinearRegressionModel model = new LinearRegressionModel(n, nameY, nameX);

        double[] xMeans = new double[nx];
        for (int k = 0; k < nx; k++) {
            xMeans[k] = srt.col(xIndices[k]).mean();
        }
        double yMean = srt.col(yIndex).mean();

        // Build cov(X, X) and cov(X, Y) from the table-wide covariance matrix.
        double[][] cov = srt.getCov();
        DenseMatrix covXX = new DenseMatrix(nx, nx);
        for (int r = 0; r < nx; r++) {
            for (int c = 0; c < nx; c++) {
                covXX.set(r, c, cov[xIndices[r]][xIndices[c]]);
            }
        }
        DenseMatrix covXY = new DenseMatrix(nx, 1);
        for (int r = 0; r < nx; r++) {
            covXY.set(r, 0, cov[xIndices[r]][yIndex]);
        }

        // Slopes go to beta[1..nx]; the intercept beta[0] is yMean - sum(xMean_k * beta_k).
        DenseMatrix slopes = covXX.solveLS(covXY);
        double intercept = yMean;
        for (int k = 0; k < nx; k++) {
            model.beta[k + 1] = slopes.get(k, 0);
            intercept -= xMeans[k] * model.beta[k + 1];
        }
        model.beta[0] = intercept;

        // Total sum of squares recovered from the variance of Y.
        double sst = srt.col(nameY).variance() * (srt.col(nameY).count - 1);

        // Regression sum of squares, expanded in terms of sums, means and covariances.
        // NOTE(review): Criteria.INVALID_GAIN acts as the additive start value and as the lower
        // bound for R2 and F below - presumably 0.0; confirm against Criteria.
        double shift = model.beta[0] - yMean;
        double ssr = Criteria.INVALID_GAIN + (shift * shift * n);
        for (int k = 0; k < nx; k++) {
            ssr += 2.0d * shift * srt.col(xIndices[k]).sum * model.beta[k + 1];
        }
        for (int a = 0; a < nx; a++) {
            for (int b = 0; b < nx; b++) {
                ssr += model.beta[a + 1] * model.beta[b + 1] * ((cov[xIndices[a]][xIndices[b]] * (n - 1)) + (xMeans[a] * xMeans[b] * n));
            }
        }

        model.SST = sst;
        model.SSR = ssr;
        model.SSE = sst - ssr;
        model.dfSST = n - 1;
        model.dfSSR = nx;
        model.dfSSE = (n - nx) - 1;
        model.R2 = Math.max(Criteria.INVALID_GAIN, Math.min(1.0d, model.SSR / model.SST));
        model.R = Math.sqrt(model.R2);
        model.MST = model.SST / model.dfSST;
        model.MSR = model.SSR / model.dfSSR;
        model.MSE = model.SSE / model.dfSSE;
        model.Ra2 = 1.0d - (model.MSE / model.MST);
        model.s = Math.sqrt(model.MSE);
        model.F = model.MSR / model.MSE;
        if (model.F < Criteria.INVALID_GAIN) {
            model.F = Criteria.INVALID_GAIN;
        }
        model.AIC = (n * Math.log(model.SSE)) + (2 * nx);

        // Scale the covariance to a scatter matrix and invert it (least-squares solve against
        // the identity); its diagonal feeds the per-coefficient F and T statistics.
        covXX.scaleEqual(n - 1);
        DenseMatrix scatterInverse = covXX.solveLS(JMatrixFunc.identity(covXX.numRows(), covXX.numRows()));
        for (int k = 0; k < nx; k++) {
            double slope = model.beta[k + 1];
            double diag = scatterInverse.get(k, k);
            model.FX[k] = (slope * slope) / (model.MSE * diag);
            model.TX[k] = slope / (model.s * Math.sqrt(diag));
        }

        try {
            // Residual degrees of freedom n - p - 1 with p = number of features, matching
            // dfSSE above.
            int p = nameX.length;
            long residualDf = (n - p) - 1;
            model.pEquation = 1.0d - CDF.F(model.F, p, residualDf);
            model.pX = new double[nx];
            for (int k = 0; k < nx; k++) {
                model.pX[k] = (1.0d - CDF.studentT(Math.abs(model.TX[k]), residualDf)) * 2.0d;
            }
        } catch (Exception ignored) {
            // Best effort: if the p-values cannot be computed, the model is still returned with
            // whatever p-value fields were assigned before the failure.
        }
        return model;
    }

    /** Rejects any column type other than Double, Long or Integer. */
    private static void requireNumericType(Class<?> type) throws Exception {
        if (type != Double.class && type != Long.class && type != Integer.class) {
            throw new Exception("col type must be double or bigint!");
        }
    }

    /** Rejects a column whose summary reports at least one missing or NaN value. */
    private static void requireNoMissingOrNan(SummaryResultTable srt, int colIndex, String colName) throws Exception {
        if (srt.col(colIndex).countMissValue > 0 || srt.col(colIndex).countNanValue > 0) {
            throw new Exception("col " + colName + " has null value or nan value!");
        }
    }
}
