package com.alibaba.alink.operator.common.timeseries.arma;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.operator.common.timeseries.AbstractGradientTarget;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.ArrayList;

/* loaded from: input_file:com/alibaba/alink/operator/common/timeseries/arma/CSSGradientTarget.class */
/**
 * Sum-of-squared-residuals objective ("CSS", presumably conditional sum of squares) and its gradient for
 * ARMA coefficients.
 *
 * <p>The coefficient vector handled by {@link #f} and {@link #gradient} is a single column laid out as
 * {@code [ar_0 .. ar_{p-1}, ma_0 .. ma_{q-1}, (intercept if ifIntercept == 1)]}.
 */
public class CSSGradientTarget extends AbstractGradientTarget {
    /** Observed series passed to {@link #fit}; also installed as the single-column X matrix. */
    public double[] data;
    /** Residual series passed to {@link #fit}; cloned (never written) by {@link #computeRSS}. */
    public double[] cResidual;
    /** Number of AR coefficients. */
    public int p;
    /** Number of MA coefficients. */
    public int q;
    /** 1 when the coefficient vector carries a trailing intercept. */
    public int ifIntercept;
    /** Offset added to the first summed index when {@code type != 1}. */
    public int optOrder;
    /** 1 to recompute residuals recursively while summing; other values keep {@link #cResidual} fixed. */
    public int type;
    /** Working residuals produced by the most recent {@link #computeRSS} call. */
    public double[] iterResidual;
    /** Objective value computed by the most recent {@link #f} call. */
    public double css;

    /**
     * Stores the series and settings and builds the initial coefficient column.
     *
     * @param initCoefs   element 0: AR coefficients, element 1: MA coefficients, element 2 (first entry):
     *                    intercept. Element 2 is read even when no intercept is used.
     * @param data        observed series
     * @param cResidual   residual series used as the starting residuals in {@link #computeRSS}
     * @param optOrder    offset added to the first summed index when {@code type != 1}
     * @param type        1 to recompute residuals recursively, any other value to keep them fixed
     * @param ifIntercept non-zero to append the intercept as the last initial coefficient
     */
    public void fit(ArrayList<double[]> initCoefs, double[] data, double[] cResidual, int optOrder, int type,
                    int ifIntercept) {
        double[] arCoef = initCoefs.get(0);
        double[] maCoef = initCoefs.get(1);
        double intercept = initCoefs.get(2)[0];
        this.p = arCoef.length;
        this.q = maCoef.length;
        this.data = data;
        this.cResidual = cResidual;
        this.type = type;
        this.optOrder = optOrder;

        double[][] column = new double[data.length][1];
        for (int t = 0; t < data.length; t++) {
            column[t][0] = data[t];
        }
        super.setX(new DenseMatrix(column));
        this.ifIntercept = ifIntercept;

        // NOTE(review): any non-zero ifIntercept appends an intercept here, while f/gradient/computeRSS
        // only treat ifIntercept == 1 as "has intercept". Callers are expected to pass 0 or 1.
        boolean appendIntercept = ifIntercept != 0;
        int numCoef = arCoef.length + maCoef.length + (appendIntercept ? 1 : 0);
        double[][] coef = new double[numCoef][1];
        for (int j = 0; j < arCoef.length; j++) {
            coef[j][0] = arCoef[j];
        }
        for (int j = 0; j < maCoef.length; j++) {
            coef[arCoef.length + j][0] = maCoef[j];
        }
        if (appendIntercept) {
            coef[numCoef - 1][0] = intercept;
        }
        this.initCoef = numCoef == 0 ? DenseMatrix.zeros(0, 0) : new DenseMatrix(coef);
    }

    /**
     * Computes the residual at index {@code t}:
     * {@code data[t] - sum_j ar[j]*data[t-j-1] - sum_j ma[j]*residual[t-j-1] - intercept}.
     * The MA sum uses at most {@code t} terms, so it never indexes before the start of {@code residual}.
     *
     * @param t         index of the residual; must be at least {@code arCoef.length}
     * @param data      observed series
     * @param residual  residual series supplying the MA lags
     * @param arCoef    AR coefficients
     * @param maCoef    MA coefficients
     * @param intercept intercept subtracted from the result
     * @return the residual at {@code t}
     */
    public double oneRSS(int t, double[] data, double[] residual, double[] arCoef, double[] maCoef,
                         double intercept) {
        double value = data[t];
        for (int j = 0; j < arCoef.length; j++) {
            value -= arCoef[j] * data[t - j - 1];
        }
        int maTerms = Math.min(maCoef.length, t);
        double maSum = 0.0d;
        for (int j = 0; j < maTerms; j++) {
            maSum += maCoef[j] * residual[t - j - 1];
        }
        return (value - maSum) - intercept;
    }

    /**
     * Sums squared residuals from the first usable index to the end of {@code data}.
     *
     * <p>Side effect: replaces {@link #iterResidual} with a clone of {@code cResidual}. When
     * {@code type == 1}, each freshly computed residual is written back into {@link #iterResidual}, so later
     * indices use the recomputed residuals as their MA lags.
     *
     * @param data        observed series
     * @param cResidual   starting residuals (not modified)
     * @param optOrder    offset added to the first summed index when {@code type != 1}
     * @param arCoef      AR coefficients
     * @param maCoef      MA coefficients
     * @param intercept   intercept; only used when {@code ifIntercept == 1}
     * @param type        1 for recursive residual updates (summation starts at {@code arCoef.length})
     * @param ifIntercept 1 to apply {@code intercept}, otherwise a zero intercept is used
     * @return the sum of squared residuals
     */
    public double computeRSS(double[] data, double[] cResidual, int optOrder, double[] arCoef, double[] maCoef,
                             double intercept, int type, int ifIntercept) {
        this.iterResidual = cResidual.clone();
        boolean recursive = type == 1;
        int begin = recursive ? arCoef.length : arCoef.length + optOrder;
        double effectiveIntercept = ifIntercept == 1 ? intercept : 0.0d;
        double sum = 0.0d;
        for (int t = begin; t < data.length; t++) {
            double r = oneRSS(t, data, this.iterResidual, arCoef, maCoef, effectiveIntercept);
            if (recursive) {
                this.iterResidual[t] = r;
            }
            sum += r * r;
        }
        return sum;
    }

    /**
     * Partial derivative of the squared-residual sum with respect to one coefficient.
     *
     * <p>For {@code paramKind == 1} (AR coefficient {@code lag}) each term is
     * {@code -2 * r_t * data[t-lag-1]}. For {@code 2} (MA coefficient {@code lag}) it is
     * {@code -2 * r_t * residual[t-lag-1]}, treating the lagged residuals as constants. For {@code 3}
     * (intercept) it is {@code -2 * r_t}. Any other kind yields {@code 0}.
     *
     * <p>NOTE(review): summation starts at {@code max(p, q)} (+ {@code optOrder} when {@code type != 1}),
     * while {@link #computeRSS} starts at {@code p}. The two ranges differ when {@code q > p}.
     *
     * @param lag       zero-based coefficient index within its group (ignored for the intercept)
     * @param data      observed series
     * @param residual  residual series supplying the MA lags
     * @param optOrder  offset added to the first summed index when {@code type != 1}
     * @param arCoef    AR coefficients
     * @param maCoef    MA coefficients
     * @param intercept intercept used in the residuals
     * @param type      1 to start summing at {@code max(p, q)} without the {@code optOrder} offset
     * @param paramKind 1 = AR, 2 = MA, 3 = intercept
     * @return the partial derivative
     */
    public double pComputeRSS(int lag, double[] data, double[] residual, int optOrder, double[] arCoef,
                              double[] maCoef, double intercept, int type, int paramKind) {
        int begin = Math.max(arCoef.length, maCoef.length) + (type == 1 ? 0 : optOrder);
        double grad = 0.0d;
        switch (paramKind) {
            case 1:
                for (int t = begin; t < data.length; t++) {
                    grad -= (2.0d * oneRSS(t, data, residual, arCoef, maCoef, intercept)) * data[t - lag - 1];
                }
                break;
            case 2:
                for (int t = begin; t < data.length; t++) {
                    grad -= (2.0d * oneRSS(t, data, residual, arCoef, maCoef, intercept)) * residual[t - lag - 1];
                }
                break;
            case 3:
                for (int t = begin; t < data.length; t++) {
                    grad -= 2.0d * oneRSS(t, data, residual, arCoef, maCoef, intercept);
                }
                break;
            default:
                break;
        }
        return grad;
    }

    /**
     * Gradient of the objective at {@code coef}, laid out like the coefficient vector.
     *
     * <p>Side effect: refreshes {@link #iterResidual} for {@code coef} via {@link #computeRSS}.
     *
     * @param coef coefficient column of {@code p + q} or {@code p + q + 1} rows
     * @param iter unused
     * @return gradient column with {@code p + q} rows, plus one more when {@code ifIntercept == 1}
     * @throws AkIllegalDataException if {@code coef} has the wrong number of rows
     */
    @Override // com.alibaba.alink.operator.common.timeseries.AbstractGradientTarget
    public DenseMatrix gradient(DenseMatrix coef, int iter) {
        int rows = coef.numRows();
        if (rows != this.p + this.q && rows != this.p + this.q + 1) {
            throw new AkIllegalDataException("coef is not comparable with the model.");
        }
        requireCoefRows(coef);
        double[] arCoef = arPart(coef);
        double[] maCoef = maPart(coef);
        double intercept = interceptPart(coef);
        computeRSS(this.data, this.cResidual, this.optOrder, arCoef, maCoef, intercept, this.type, this.ifIntercept);

        int numCoef = this.p + this.q + (this.ifIntercept == 1 ? 1 : 0);
        double[][] grad = new double[numCoef][1];
        for (int j = 0; j < this.p; j++) {
            grad[j][0] = pComputeRSS(j, this.data, this.iterResidual, this.optOrder, arCoef, maCoef, intercept,
                this.type, 1);
        }
        for (int j = 0; j < this.q; j++) {
            grad[this.p + j][0] = pComputeRSS(j, this.data, this.iterResidual, this.optOrder, arCoef, maCoef,
                intercept, this.type, 2);
        }
        if (this.ifIntercept == 1) {
            grad[this.p + this.q][0] = pComputeRSS(0, this.data, this.iterResidual, this.optOrder, arCoef, maCoef,
                intercept, this.type, 3);
        }
        return new DenseMatrix(grad);
    }

    /**
     * Evaluates the objective at {@code coef}.
     *
     * <p>Side effects: stores the value in {@link #css} and exposes the recomputed residuals through
     * {@code residual}.
     *
     * @param coef coefficient column
     * @return sum of squared residuals
     * @throws AkIllegalDataException if {@code coef} has fewer rows than this model reads
     */
    @Override // com.alibaba.alink.operator.common.timeseries.AbstractGradientTarget
    public double f(DenseMatrix coef) {
        requireCoefRows(coef);
        this.css = computeRSS(this.data, this.cResidual, this.optOrder, arPart(coef), maPart(coef),
            interceptPart(coef), this.type, this.ifIntercept);
        this.residual = this.iterResidual;
        return this.css;
    }

    /** Rejects a coefficient column too short for the AR, MA and (optional) intercept entries read from it. */
    private void requireCoefRows(DenseMatrix coef) {
        int required = this.p + this.q + (this.ifIntercept == 1 ? 1 : 0);
        if (coef.numRows() < required) {
            throw new AkIllegalDataException("coef is not comparable with the model.");
        }
    }

    /** Copies the first {@code p} rows of {@code coef} (AR coefficients). */
    private double[] arPart(DenseMatrix coef) {
        double[] ar = new double[this.p];
        for (int j = 0; j < this.p; j++) {
            ar[j] = coef.get(j, 0);
        }
        return ar;
    }

    /** Copies rows {@code p .. p+q-1} of {@code coef} (MA coefficients). */
    private double[] maPart(DenseMatrix coef) {
        double[] ma = new double[this.q];
        for (int j = 0; j < this.q; j++) {
            ma[j] = coef.get(this.p + j, 0);
        }
        return ma;
    }

    /** Returns row {@code p+q} of {@code coef} when an intercept is modeled, otherwise {@code 0}. */
    private double interceptPart(DenseMatrix coef) {
        return this.ifIntercept == 1 ? coef.get(this.p + this.q, 0) : 0.0d;
    }
}
