package com.alibaba.alink.operator.common.optim.subfunc;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/subfunc/UpdateModel.class */
public class UpdateModel extends ComputeFunction {
    /** Learning rates below this value stop the optimization. */
    private static final double EPS = 1.0E-18d;
    /**
     * Number of slots in the ring buffer {@code sKyK.f0} that receives each step's coefficient change.
     * NOTE(review): must equal the history length allocated and read by the LBFGS/OWLQN direction
     * computation that shares {@code OptimVariable.sKyK} -- TODO confirm against that code.
     */
    private static final int NUM_CORRECTIONS = 10;
    /** Slots per step in {@code OptimVariable.convergenceInfo}: loss, gradient norm, learning rate. */
    private static final int CONVERGENCE_INFO_WIDTH = 3;
    private static final Logger LOG = LoggerFactory.getLogger(UpdateModel.class);
    private static final long serialVersionUID = -5794349014922265691L;
    private final int numSearchStep;
    private final LinearTrainParams.OptimMethod method;
    private final String gradName;
    private final Params params;

    /**
     * Creates the update step of a line-search based linear optimizer.
     *
     * @param params      training parameters; {@code LinearTrainParams.EPSILON} and
     *                    {@code LinearTrainParams.MAX_ITER} are read by {@link #filter}
     * @param str         name of the context object whose {@code f0} vector's L2 norm is used as the
     *                    gradient norm in the stop criteria
     * @param optimMethod optimizer in use; OWLQN and LBFGS get dedicated coefficient updates
     * @param i           number of line-search increments the learning rate is split into
     */
    public UpdateModel(Params params, String str, LinearTrainParams.OptimMethod optimMethod, int i) {
        this.method = optimMethod;
        this.gradName = str;
        this.params = params;
        this.numSearchStep = i;
    }

    /**
     * Picks the best line-search step, adapts the learning rate, moves the coefficients and
     * evaluates the stop criteria.
     *
     * <p>Entry {@code j} of {@code "lossAllReduce"} is the (summed) loss at step size
     * {@code j * learningRate / numSearchStep}, where the learning rate is {@code dir.f1[1]}.
     * The losses array is normalized in place, and its entry 0 is overwritten with the minimum loss.
     */
    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    public void calc(ComContext comContext) {
        double[] losses = (double[]) comContext.getObj("lossAllReduce");
        Tuple2<DenseVector, double[]> dir = (Tuple2) comContext.getObj(OptimVariable.dir);
        Tuple2 pseGrad = (Tuple2) comContext.getObj(OptimVariable.pseGrad);
        Tuple2 grad = (Tuple2) comContext.getObj(OptimVariable.grad);
        Tuple2<DenseVector, Double> curCoef = (Tuple2) comContext.getObj(OptimVariable.currentCoef);
        Tuple2<DenseVector, Double> minCoef = (Tuple2) comContext.getObj(OptimVariable.minCoef);
        double[] convergenceInfo = (double[]) comContext.getObj(OptimVariable.convergenceInfo);
        // dirInfo[0]: divisor that normalizes the all-reduced losses (presumably the total sample
        // weight -- TODO confirm); it is set to -1 by filter() to request termination.
        // dirInfo[1]: current learning rate.
        double[] dirInfo = dir.f1;

        for (int j = 0; j < losses.length; j++) {
            losses[j] /= dirInfo[0];
        }

        // Running minimum kept in losses[0]; bestPos stays -1 when no step beats the zero step.
        int bestPos = -1;
        for (int j = 0; j < losses.length; j++) {
            if (losses[j] < losses[0]) {
                losses[0] = losses[j];
                bestPos = j;
            }
        }

        double beta = dirInfo[1] / numSearchStep;
        double eta;
        double lossChangeRatio = 1.0d;
        if (bestPos == -1) {
            // No candidate improved the loss: stay put and retry with a much smaller learning rate.
            eta = 0.0d;
            // Multiply in double to avoid int overflow for very large numSearchStep.
            dirInfo[1] *= 1.0d / ((double) numSearchStep * numSearchStep);
            curCoef.f1 = losses[0];
        } else {
            eta = beta * bestPos;
            if (bestPos == numSearchStep) {
                // The largest step was best: enlarge the learning rate, capped at numSearchStep.
                dirInfo[1] = Math.min(dirInfo[1] * numSearchStep, numSearchStep);
            }
            lossChangeRatio = Math.abs((curCoef.f1 - losses[bestPos]) / curCoef.f1);
            curCoef.f1 = losses[bestPos];
        }

        int offset = (comContext.getStepNo() - 1) * CONVERGENCE_INFO_WIDTH;
        convergenceInfo[offset] = curCoef.f1;
        // Fall back to the descent direction's norm when no gradient object is registered.
        convergenceInfo[offset + 1] = grad != null ? ((DenseVector) grad.f0).normL2() : dir.f0.normL2();
        convergenceInfo[offset + 2] = dirInfo[1];

        updateCoefficients(comContext, dir.f0, curCoef.f0, pseGrad, eta);

        if (curCoef.f1 < minCoef.f1) {
            minCoef.f1 = curCoef.f1;
            minCoef.f0.setEqual(curCoef.f0);
        }
        filter(dir, curCoef, minCoef, comContext, lossChangeRatio);
    }

    /**
     * Applies {@code coef -= eta * direction} according to the optimizer in use.
     *
     * <p>OWLQN additionally projects each coordinate onto its orthant: a nonzero weight that would
     * change sign is clipped to 0, and a zero weight may only move opposite to the pseudo-gradient.
     * OWLQN and LBFGS also record the resulting coefficient change in the {@code sKyK.f0} ring buffer.
     */
    private void updateCoefficients(ComContext comContext, DenseVector direction, DenseVector coef,
                                    Tuple2 pseGrad, double eta) {
        if (method == LinearTrainParams.OptimMethod.OWLQN) {
            DenseVector step = currentStepSlot(comContext);
            int size = direction.size();
            for (int j = 0; j < size; j++) {
                double oldVal = coef.get(j);
                double newVal = oldVal - direction.get(j) * eta;
                if (Math.abs(oldVal) > 0.0d) {
                    if (newVal * oldVal < 0.0d) {
                        newVal = 0.0d;
                    }
                } else if (newVal * ((DenseVector) pseGrad.f0).get(j) > 0.0d) {
                    newVal = 0.0d;
                }
                step.set(j, newVal - oldVal);
                coef.set(j, newVal);
            }
        } else if (method == LinearTrainParams.OptimMethod.LBFGS) {
            DenseVector step = currentStepSlot(comContext);
            int size = direction.size();
            for (int j = 0; j < size; j++) {
                step.set(j, direction.get(j) * (-eta));
            }
            coef.plusScaleEqual((Vector) direction, -eta);
        } else {
            coef.plusScaleEqual((Vector) direction, -eta);
        }
    }

    /** Returns the {@code sKyK.f0} ring-buffer slot for the current step. */
    private static DenseVector currentStepSlot(ComContext comContext) {
        DenseVector[] stepHistory = (DenseVector[]) ((Tuple2) comContext.getObj(OptimVariable.sKyK)).f0;
        return stepHistory[(comContext.getStepNo() - 1) % NUM_CORRECTIONS];
    }

    /**
     * Evaluates the stop criteria, logs the outcome, and sets {@code tuple2.f1[0] = -1} when the
     * optimization should stop.
     *
     * @param tuple2     direction tuple; {@code f1[1]} is the learning rate, {@code f1[0]} the stop flag slot
     * @param tuple22    current coefficients and their loss
     * @param tuple23    best coefficients seen so far and their loss
     * @param comContext iteration context (step number, task id, gradient object)
     * @param d          relative loss change of this step
     * @throws AkIllegalDataException if the gradient norm is NaN or infinite
     */
    public void filter(Tuple2<DenseVector, double[]> tuple2, Tuple2<DenseVector, Double> tuple22, Tuple2<DenseVector, Double> tuple23, ComContext comContext, double d) {
        double epsilon = ((Double) this.params.get(LinearTrainParams.EPSILON)).doubleValue();
        int maxIter = ((Integer) this.params.get(LinearTrainParams.MAX_ITER)).intValue();
        double gradNorm = ((DenseVector) ((Tuple2) comContext.getObj(this.gradName)).f0).normL2();
        if (Double.isNaN(gradNorm) || Double.isInfinite(gradNorm)) {
            throw new AkIllegalDataException("Optimize method not converged, may be your input has NaN or infinite number, or some train parameters not appropriate, or optimize method not suit for your data!");
        }
        double[] dirInfo = tuple2.f1;
        double curLoss = tuple22.f1;
        double minLoss = tuple23.f1;
        double lossChangeRatio = d;

        String message;
        boolean stop = true;
        if (curLoss < epsilon || gradNorm < epsilon) {
            message = " method converged at step : ";
        } else if (comContext.getStepNo() >= maxIter) {
            message = " method stop at max step : ";
        } else if (dirInfo[1] < EPS) {
            message = " learning rate is too small, method stops at step : ";
        } else if (lossChangeRatio >= epsilon || gradNorm >= Math.sqrt(epsilon)) {
            message = " method continue at step : ";
            stop = false;
        } else {
            message = " loss change ratio is too small, method stops at step : ";
        }
        printLog(message, curLoss, minLoss, dirInfo[1], gradNorm, comContext, lossChangeRatio);
        if (stop) {
            dirInfo[0] = -1.0d;
        }
    }

    /**
     * Reports one iteration's status: always to the SLF4J log, and additionally to stdout from
     * task 0 when process-info printing is enabled.
     */
    private void printLog(String message, double curLoss, double minLoss, double learningRate,
                          double gradNorm, ComContext comContext, double lossChangeRatio) {
        if (comContext.getTaskId() == 0 && AlinkGlobalConfiguration.isPrintProcessInfo()) {
            System.out.println(this.method.toString() + message + comContext.getStepNo() + " cur loss : " + curLoss + " min loss : " + minLoss + " grad norm : " + gradNorm + " learning rate : " + learningRate + " loss change ratio : " + lossChangeRatio);
        }
        // message already ends with " : ", so the step placeholder follows it directly.
        LOG.info(this.method.toString() + message + "{}, cur loss: {}, min loss: {}, grad norm: {}, learning rate: {}, loss change ratio: {}",
            comContext.getStepNo(), curLoss, minLoss, gradNorm, learningRate, lossChangeRatio);
    }
}
