package com.alibaba.alink.operator.common.timeseries;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.ArrayList;

/* loaded from: input_file:com/alibaba/alink/operator/common/timeseries/BFGS.class */
/**
 * BFGS quasi-Newton minimizer for an {@link AbstractGradientTarget}.
 *
 * <p>The solver keeps an approximation {@code H} of the inverse Hessian, moves along
 * {@code -H * gradient} with a backtracking line search, and refreshes {@code H} with the
 * standard BFGS inverse update after every accepted step. All coefficient and gradient
 * matrices are read through column 0 only, i.e. they are handled as column vectors.
 */
public class BFGS {
    /** Option code: once any search-direction entry exceeds 1.0, rescale directions to unit L2 norm. */
    private static final int OPTION_NORMALIZE_DIRECTION = 1;

    /** Option code: divide the gradient by the sample size. */
    private static final int OPTION_MEAN_GRADIENT = 2;

    /** Option code: additionally stop when the gradient change between iterations is small. */
    private static final int OPTION_WEAK_CONVERGENCE = 3;

    /**
     * Value of {@code getSampleSize()} for which the row count of {@code getX()} is used instead.
     * NOTE(review): presumably an "unset" marker; confirm against AbstractGradientTarget.
     */
    private static final int SAMPLE_SIZE_MARKER = -99;

    /** Sufficient-decrease (Armijo) coefficient used by the line search. */
    private static final double ARMIJO_COEFFICIENT = 1.0E-5d;

    /** Factor applied to the step length every time the sufficient-decrease test fails. */
    private static final double STEP_SHRINK_FACTOR = 0.5d;

    /** Upper bound on the number of step-length reductions in one line search. */
    private static final int MAX_STEP_REDUCTIONS = 10000;

    /**
     * Minimizes {@code abstractGradientTarget.f} starting from its initial coefficients and stores
     * the outcome (iteration counter, final coefficients, final {@code H}, objective value) on the
     * target itself.
     *
     * <p>Iteration stops as soon as one of the following holds:
     * <ul>
     *   <li>the iteration counter, which starts at 1, reaches {@code i};</li>
     *   <li>the candidate coefficient at row {@code i2} is below {@code Criteria.INVALID_GAIN}
     *       (the candidate is discarded);</li>
     *   <li>the updated {@code H} has a NaN in its first column (the candidate is discarded and the
     *       target's warn flag is set to {@code "1"});</li>
     *   <li>the L2 norm of the new gradient is below {@code d} (strong convergence);</li>
     *   <li>option 3 is active and the L2 norm of the gradient change is below {@code d2}
     *       (weak convergence).</li>
     * </ul>
     *
     * @param abstractGradientTarget objective providing {@code f}, {@code gradient} and the initial
     *                               coefficients; also receives the results
     * @param i                      exclusive upper bound for the iteration counter
     * @param d                      strong-convergence threshold on the gradient norm
     * @param d2                     weak-convergence threshold on the norm of the gradient change
     * @param iArr                   active option codes (1 = normalize large directions,
     *                               2 = mean gradient, 3 = weak-convergence test)
     * @param i2                     row of a coefficient that must not fall below
     *                               {@code Criteria.INVALID_GAIN}; a negative value or a value past
     *                               the last row disables the check (presumably the variance term;
     *                               confirm against callers)
     * @return the same {@code abstractGradientTarget} instance, updated with the results
     */
    public static AbstractGradientTarget solve(AbstractGradientTarget abstractGradientTarget, int i, double d, double d2, int[] iArr, int i2) {
        final boolean normalizeOptionOn = hasOption(iArr, OPTION_NORMALIZE_DIRECTION);
        final boolean useMeanGradient = hasOption(iArr, OPTION_MEAN_GRADIENT);
        final boolean weakConvergenceOn = hasOption(iArr, OPTION_WEAK_CONVERGENCE);

        final DenseMatrix identity = DenseMatrix.eye(abstractGradientTarget.getInitCoef().numRows());
        DenseMatrix inverseHessian = identity.m134clone();
        DenseMatrix coef = abstractGradientTarget.getInitCoef();
        DenseMatrix gradient = ifMeanGradient(abstractGradientTarget, coef, 0, useMeanGradient);

        // Sticky: once a direction entry has exceeded 1.0, every later direction is normalized too.
        boolean normalizeDirection = false;

        int iter = 1;
        while (iter < i) {
            DenseMatrix direction = inverseHessian.multiplies(gradient).scale(-1.0d);
            if (normalizeOptionOn && anyFirstColumnEntryAbove(direction, 1.0d)) {
                normalizeDirection = true;
            }
            if (normalizeDirection) {
                direction = direction.scale(1.0d / norm2(direction));
            }

            double stepLength = backtrackLineSearch(abstractGradientTarget, coef, direction, gradient, i2);
            DenseMatrix step = direction.scale(stepLength);
            DenseMatrix candidateCoef = coef.plus(step);

            // Reject the candidate and stop when the guarded coefficient drops below the threshold.
            // NOTE(review): Criteria.INVALID_GAIN comes from the tree package; confirm it is the intended bound.
            if (i2 >= 0 && candidateCoef.numRows() > i2 && candidateCoef.get(i2, 0) < Criteria.INVALID_GAIN) {
                break;
            }

            DenseMatrix candidateGradient = ifMeanGradient(abstractGradientTarget, candidateCoef, iter, useMeanGradient);
            DenseMatrix gradientChange = candidateGradient.minus(gradient);

            // BFGS inverse update: H' = (I - rho*s*y^T) * H * (I - rho*y*s^T) + rho*s*s^T, rho = 1 / (y^T s).
            double rho = 1.0d / gradientChange.transpose().multiplies(step).get(0, 0);
            DenseMatrix updatedInverseHessian = identity
                .minus(step.multiplies(gradientChange.transpose()).scale(rho))
                .multiplies(inverseHessian)
                .multiplies(identity.minus(gradientChange.multiplies(step.transpose()).scale(rho)))
                .plus(step.multiplies(step.transpose()).scale(rho));

            if (firstColumnHasNaN(updatedInverseHessian)) {
                abstractGradientTarget.setWarn("1");
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("Warn: Stop when Hessian Matrix contains NaN. Loss function CSS doesn't converge in gradient descent.");
                }
                // Without this exit the identical iteration would be repeated forever.
                break;
            }

            inverseHessian = updatedInverseHessian.m134clone();
            gradient = candidateGradient.m134clone();
            coef = candidateCoef.m134clone();

            if (norm2(candidateGradient) < d) {
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("Optimization terminated successfully.(Strong converge)");
                }
                break;
            }
            if (weakConvergenceOn && norm2(gradientChange) < d2) {
                if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                    System.out.println("Optimization terminated successfully.(Weak converge)");
                }
                break;
            }
            iter++;
        }

        abstractGradientTarget.setIter(iter);
        abstractGradientTarget.setFinalCoef(coef);
        abstractGradientTarget.setH(inverseHessian);
        abstractGradientTarget.setMinValue(abstractGradientTarget.f(coef));
        return abstractGradientTarget;
    }

    /** Returns whether {@code code} occurs in {@code options}. */
    private static boolean hasOption(int[] options, int code) {
        for (int option : options) {
            if (option == code) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether any entry in column 0 of {@code matrix} is strictly greater than {@code limit}. */
    private static boolean anyFirstColumnEntryAbove(DenseMatrix matrix, double limit) {
        for (int row = 0; row < matrix.numRows(); row++) {
            if (matrix.get(row, 0) > limit) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether column 0 of {@code matrix} contains a NaN; other columns are not inspected. */
    private static boolean firstColumnHasNaN(DenseMatrix matrix) {
        for (int row = 0; row < matrix.numRows(); row++) {
            if (Double.isNaN(matrix.get(row, 0))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Backtracking line search: starts with step length 1 and halves it while the decrease
     * {@code f(coef) - f(coef + step * direction)} is smaller than
     * {@code -1e-5 * step * direction^T * gradient}, for at most 10000 halvings.
     *
     * <p>Shrinking only happens when the guarded coefficient of the full (length 1) step is not
     * below {@code Criteria.INVALID_GAIN}; that guard is evaluated once, before the loop, and not
     * per trial step. NOTE(review): confirm that checking only the full step is intended.
     *
     * @param target     objective whose {@code f} is evaluated
     * @param coef       current coefficients
     * @param direction  search direction
     * @param gradient   gradient at {@code coef}
     * @param guardedRow row of the guarded coefficient; negative or out-of-range disables the guard
     *                   lookup (the guarded value then stays 0.0)
     * @return the accepted step length
     */
    private static double backtrackLineSearch(AbstractGradientTarget target, DenseMatrix coef, DenseMatrix direction, DenseMatrix gradient, int guardedRow) {
        double stepLength = 1.0d;

        DenseMatrix fullStepCoef = coef.plus(direction.scale(1.0d));
        double guardedValue = 0.0d;
        if (guardedRow >= 0 && fullStepCoef.numRows() > guardedRow) {
            guardedValue = fullStepCoef.get(guardedRow, 0);
        }

        for (int remaining = MAX_STEP_REDUCTIONS; remaining > 0; remaining--) {
            double decrease = target.f(coef) - target.f(coef.plus(direction.scale(stepLength)));
            double requiredDecrease = direction.transpose().multiplies(gradient)
                .scale(stepLength * ARMIJO_COEFFICIENT).scale(-1.0d).get(0, 0);
            // Both comparisons are deliberately written so that a NaN operand ends the search.
            boolean keepShrinking = decrease < requiredDecrease && guardedValue >= Criteria.INVALID_GAIN;
            if (!keepShrinking) {
                break;
            }
            stepLength *= STEP_SHRINK_FACTOR;
        }
        return stepLength;
    }

    /**
     * Returns the L2 norm of {@code matrix}: the vector norm of column 0 when the matrix has
     * exactly one column, otherwise whatever {@code DenseMatrix.norm2()} yields.
     */
    private static double norm2(DenseMatrix matrix) {
        if (matrix.numCols() == 1) {
            return new DenseVector(matrix.getColumn(0)).normL2();
        }
        return matrix.norm2();
    }

    /**
     * Evaluates the target's gradient at {@code coef}, optionally divided by the sample size.
     *
     * @param target  objective whose {@code gradient} is evaluated
     * @param coef    point at which the gradient is taken
     * @param iter    iteration counter forwarded unchanged to {@code gradient}
     * @param useMean when {@code true} the gradient is divided by {@code getSampleSize()}, or by the
     *                row count of {@code getX()} if the sample size equals -99
     * @return the raw or averaged gradient
     */
    private static DenseMatrix ifMeanGradient(AbstractGradientTarget target, DenseMatrix coef, int iter, boolean useMean) {
        if (!useMean) {
            return target.gradient(coef, iter);
        }
        if (target.getSampleSize() == SAMPLE_SIZE_MARKER) {
            return target.gradient(coef, iter).scale(1.0d / target.getX().numRows());
        }
        return target.gradient(coef, iter).scale(1.0d / target.getSampleSize());
    }
}
