package com.alibaba.alink.operator.common.regression;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.jama.JMatrixFunc;
import com.alibaba.alink.common.probabilistic.CDF;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.ArrayList;

/* loaded from: input_file:com/alibaba/alink/operator/common/regression/RidgeRegressionProcess.class */
/**
 * Fits a family of ridge regression models from precomputed summary statistics.
 *
 * <p>For every ridge penalty {@code k} passed to {@link #calc(double[])}, the slope coefficients
 * solve {@code (C_xx + k * I) * beta = C_xy}, where {@code C} is the matrix returned by
 * {@code SummaryResultTable.getCov()} (presumably the sample covariance matrix). The intercept is
 * then chosen so that the fitted hyperplane passes through the column means. Goodness-of-fit
 * statistics (SST, SSR, SSE, R2, F, AIC, per-coefficient F/t statistics and p-values) are written
 * into one {@link LinearRegressionModel} per penalty.
 */
public class RidgeRegressionProcess {
    /** Name of the dependent (label) column. */
    private final String nameY;
    /** Names of the independent (feature) columns; a private copy of the caller's array, may be null. */
    private final String[] nameX;
    /** Summary statistics of the training table. */
    private final SummaryResultTable srt;

    /**
     * Creates a process bound to the given statistics and column selection.
     *
     * @param summaryResultTable summary statistics of the training table
     * @param str name of the dependent column
     * @param strArr names of the independent columns; copied defensively, may be {@code null}
     */
    public RidgeRegressionProcess(SummaryResultTable summaryResultTable, String str, String[] strArr) {
        this.srt = summaryResultTable;
        this.nameY = str;
        this.nameX = (strArr == null) ? null : strArr.clone();
    }

    /**
     * Convenience entry point: builds a process and runs {@link #calc(double[])}.
     *
     * @param summaryResultTable summary statistics of the training table
     * @param str name of the dependent column
     * @param strArr names of the independent columns
     * @param dArr ridge penalties, one model is fitted per entry
     * @return the fitted models, index-aligned with {@code dArr}
     * @throws Exception on invalid input, see {@link #calc(double[])}
     */
    public static RidgeRegressionProcessResult calc(SummaryResultTable summaryResultTable, String str, String[] strArr, double[] dArr) throws Exception {
        return new RidgeRegressionProcess(summaryResultTable, str, strArr).calc(dArr);
    }

    /**
     * Fits one ridge regression model per entry of {@code dArr}.
     *
     * @param dArr ridge penalties {@code k}; each is added to the diagonal of the feature block of
     *     the covariance matrix before solving for the coefficients
     * @return a result whose {@code kVals} and {@code lrModels} are index-aligned with {@code dArr}
     * @throws Exception if a selected column is not Double/Long/Integer typed, no feature columns
     *     were given (empty or {@code null}), a selected column contains null or NaN values, the
     *     table is empty, it has fewer records than features, or the label has no valid values
     */
    public RidgeRegressionProcessResult calc(double[] dArr) throws Exception {
        final String[] colNames = this.srt.colNames;

        final int yIndex = TableUtil.findColIndexWithAssertAndHint(colNames, this.nameY);
        checkNumericType(this.srt.col(yIndex).dataType);
        if (this.nameX == null || this.nameX.length == 0) {
            throw new Exception("nameX must input!");
        }

        final int p = this.nameX.length;
        final int[] xIdx = new int[p];
        for (int i = 0; i < p; i++) {
            xIdx[i] = TableUtil.findColIndexWithAssertAndHint(colNames, this.nameX[i]);
            checkNumericType(this.srt.col(xIdx[i]).dataType);
        }

        checkNoMissing(yIndex, this.nameY);
        for (int i = 0; i < p; i++) {
            checkNoMissing(xIdx[i], this.nameX[i]);
        }
        if (this.srt.col(0).countTotal == 0) {
            throw new Exception("table is empty!");
        }
        if (this.srt.col(0).countTotal < p) {
            throw new Exception("record size Less than features size!");
        }
        final long n = this.srt.col(yIndex).count;
        if (n == 0) {
            throw new Exception("Y valid value num is zero!");
        }

        // Keep only feature names whose column has at least one valid value. NOTE(review): xIdx is
        // not filtered alongside, so a dropped name would misalign nameX with the coefficients;
        // after the missing/NaN checks above this presumably never drops anything - confirm.
        ArrayList<String> validNames = new ArrayList<>();
        for (int i = 0; i < p; i++) {
            if (this.srt.col(xIdx[i]).count != 0) {
                validNames.add(this.nameX[i]);
            }
        }
        validNames.toArray(this.nameX);

        final double[] xMeans = new double[p];
        for (int i = 0; i < p; i++) {
            xMeans[i] = this.srt.col(xIdx[i]).mean();
        }
        final double yMean = this.srt.col(yIndex).mean();
        final double[][] cov = this.srt.getCov();

        // Right-hand side of the normal equations: covariance of each feature with the label.
        final DenseMatrix cxy = new DenseMatrix(p, 1);
        for (int i = 0; i < p; i++) {
            cxy.set(i, 0, cov[xIdx[i]][yIndex]);
        }

        // Total sum of squares does not depend on the penalty, so compute it once.
        final double sst = this.srt.col(this.nameY).variance() * (this.srt.col(this.nameY).count - 1);
        // Residual degrees of freedom: n - p - 1 (intercept plus p slopes).
        final long dfResidual = (n - p) - 1;

        RidgeRegressionProcessResult result = new RidgeRegressionProcessResult(dArr.length);
        for (int m = 0; m < dArr.length; m++) {
            final double k = dArr[m];
            result.kVals[m] = k;
            LinearRegressionModel model = new LinearRegressionModel(n, this.nameY, this.nameX);
            result.lrModels[m] = model;

            // Ridge system matrix: C_xx + k * I.
            DenseMatrix system = new DenseMatrix(p, p);
            for (int r = 0; r < p; r++) {
                for (int c = 0; c < p; c++) {
                    double v = cov[xIdx[r]][xIdx[c]];
                    system.set(r, c, r == c ? v + k : v);
                }
            }
            DenseMatrix coef = system.solveLS(cxy);

            // Slopes from the solve; intercept makes the fit pass through the means.
            double intercept = yMean;
            for (int r = 0; r < p; r++) {
                model.beta[r + 1] = coef.get(r, 0);
                intercept -= xMeans[r] * model.beta[r + 1];
            }
            model.beta[0] = intercept;

            // Regression sum of squares, expanded from sums and cross-products so it needs only
            // the summary statistics. NOTE(review): Criteria.INVALID_GAIN appears to stand in for
            // a literal 0.0 here (decompiler constant substitution?) - confirm.
            double shift = model.beta[0] - yMean;
            double ssr = Criteria.INVALID_GAIN + (shift * shift * n);
            for (int r = 0; r < p; r++) {
                ssr += 2.0d * shift * this.srt.col(xIdx[r]).sum * model.beta[r + 1];
            }
            for (int r = 0; r < p; r++) {
                for (int c = 0; c < p; c++) {
                    ssr += model.beta[r + 1] * model.beta[c + 1]
                        * ((cov[xIdx[r]][xIdx[c]] * (n - 1)) + (xMeans[r] * xMeans[c] * n));
                }
            }

            model.SST = sst;
            model.SSR = ssr;
            model.SSE = sst - ssr;
            model.dfSST = n - 1;
            model.dfSSR = p;
            model.dfSSE = dfResidual;
            model.R2 = Math.max(Criteria.INVALID_GAIN, Math.min(1.0d, model.SSR / model.SST));
            model.R = Math.sqrt(model.R2);
            model.MST = model.SST / model.dfSST;
            model.MSR = model.SSR / model.dfSSR;
            model.MSE = model.SSE / model.dfSSE;
            model.Ra2 = 1.0d - (model.MSE / model.MST);
            model.s = Math.sqrt(model.MSE);
            model.F = model.MSR / model.MSE;
            if (model.F < Criteria.INVALID_GAIN) {
                model.F = Criteria.INVALID_GAIN;
            }
            model.AIC = (n * Math.log(model.SSE)) + (2 * p);

            // Invert the scaled system matrix; its diagonal feeds the per-coefficient statistics.
            system.scaleEqual(n - 1);
            DenseMatrix inverse = system.solveLS(JMatrixFunc.identity(system.numRows(), system.numRows()));
            for (int r = 0; r < p; r++) {
                double b = model.beta[r + 1];
                double diag = inverse.get(r, r);
                model.FX[r] = (b * b) / (model.MSE * diag);
                model.TX[r] = b / (model.s * Math.sqrt(diag));
            }

            model.pEquation = 1.0d - CDF.F(model.F, p, dfResidual);
            model.pX = new double[p];
            for (int r = 0; r < p; r++) {
                // Two-sided p-value of the t statistic.
                model.pX[r] = (1.0d - CDF.studentT(Math.abs(model.TX[r]), dfResidual)) * 2.0d;
            }
        }
        return result;
    }

    /** Throws unless {@code type} is one of the supported numeric column types. */
    private static void checkNumericType(Class<?> type) throws Exception {
        if (type != Double.class && type != Long.class && type != Integer.class) {
            throw new Exception("col type must be double or bigint!");
        }
    }

    /** Throws if the column at {@code colIndex} has any null or NaN values. */
    private void checkNoMissing(int colIndex, String colName) throws Exception {
        if (this.srt.col(colIndex).countMissValue > 0 || this.srt.col(colIndex).countNanValue > 0) {
            throw new Exception("col " + colName + " has null value or nan value!");
        }
    }
}
