package com.alibaba.alink.operator.common.optim;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.IterTermination;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.optim.subfunc.OutputModel;
import com.alibaba.alink.operator.common.optim.subfunc.ParseRowModel;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateCoefficient;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateConvergenceInfo;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateMatrix;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateVector;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import java.util.List;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/Newton.class */
/**
 * Distributed Newton's-method optimizer built on {@link IterativeComQueue}.
 *
 * <p>Each iteration runs these compute functions in order:
 * <ol>
 *   <li>{@link CalcGradientAndHessian}: every task evaluates the objective's gradient, Hessian and
 *   loss on its data partition at the current coefficients and packs them into one flat buffer;</li>
 *   <li>{@link AllReduce}: that buffer is reduced across tasks;</li>
 *   <li>{@link GetGradeintAndHessian}: the reduced gradient, Hessian and loss are divided by the
 *   reduced weight value;</li>
 *   <li>{@link UpdateModel}: the system {@code H * d = g} is solved, the coefficients are moved by
 *   {@code -d}, and the convergence / max-iteration checks are performed.</li>
 * </ol>
 */
public class Newton extends Optimizer {
    /** Feature-dimension bound handed to {@link PreallocateMatrix} when allocating the Hessian. */
    private static final int MAX_FEATURE_NUM = 1000;

    /**
     * Evaluates the local gradient, Hessian and loss for this task's data partition and packs them
     * into the {@code OptimVariable.gradHessAllReduce} buffer consumed by the following AllReduce.
     *
     * <p>Buffer layout for feature dimension {@code n}:
     * <ul>
     *   <li>{@code [0, n)}: gradient;</li>
     *   <li>{@code [n, n + n * n)}: Hessian, row-major;</li>
     *   <li>{@code n + n * n}: first value returned by the objective (used downstream as the
     *   normalizing weight);</li>
     *   <li>{@code n + n * n + 1}: loss.</li>
     * </ul>
     */
    public static class CalcGradientAndHessian extends ComputeFunction {
        private static final long serialVersionUID = -8884412894729352793L;

        /** Objective function, fetched lazily from the broadcast list on first use. */
        private OptimObjFunc objFunc;

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            Iterable<Tuple3<Double, Double, Vector>> trainData =
                (Iterable<Tuple3<Double, Double, Vector>>) comContext.getObj("trainData");
            Tuple2<DenseVector, double[]> dir =
                (Tuple2<DenseVector, double[]>) comContext.getObj(OptimVariable.dir);
            Tuple2<DenseVector, Double> curCoef =
                (Tuple2<DenseVector, Double>) comContext.getObj(OptimVariable.currentCoef);
            DenseMatrix hessian = (DenseMatrix) comContext.getObj(OptimVariable.hessian);
            DenseVector grad = dir.f0;
            int size = grad.size();

            if (objFunc == null) {
                objFunc = ((List<OptimObjFunc>) comContext.getObj(OptimVariable.objFunc)).get(0);
            }
            // grad and hessian are passed in and read back below, so the objective presumably
            // fills them in place; the returned pair is (normalizing weight, loss).
            Tuple2<Double, Double> weightAndLoss =
                objFunc.calcHessianGradientLoss(trainData, curCoef.f0, hessian, grad);

            // The reduce buffer is allocated once and reused across iterations.
            double[] buffer = (double[]) comContext.getObj(OptimVariable.gradHessAllReduce);
            if (buffer == null) {
                buffer = new double[size + size * size + 2];
                comContext.putObj(OptimVariable.gradHessAllReduce, buffer);
            }
            for (int i = 0; i < size; i++) {
                buffer[i] = grad.get(i);
                int rowOffset = size + i * size;
                for (int j = 0; j < size; j++) {
                    buffer[rowOffset + j] = hessian.get(i, j);
                }
            }
            int tail = size + size * size;
            buffer[tail] = weightAndLoss.f0;
            buffer[tail + 1] = weightAndLoss.f1;
        }
    }

    /**
     * Unpacks the reduced buffer written by {@link CalcGradientAndHessian}.
     *
     * <p>The gradient, Hessian and loss are divided by the reduced weight value. The results are
     * written back into {@code dir.f0}, the Hessian matrix, and {@code dir.f1}:
     * {@code dir.f1[0]} receives the weight and {@code dir.f1[1]} the normalized loss.
     */
    public static class GetGradeintAndHessian extends ComputeFunction {
        private static final long serialVersionUID = 7381646408838896976L;

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            Tuple2<DenseVector, double[]> dir =
                (Tuple2<DenseVector, double[]>) comContext.getObj(OptimVariable.dir);
            DenseMatrix hessian = (DenseMatrix) comContext.getObj(OptimVariable.hessian);
            DenseVector grad = dir.f0;
            double[] info = dir.f1;
            int size = grad.size();
            double[] buffer = (double[]) comContext.getObj(OptimVariable.gradHessAllReduce);

            int tail = size + size * size;
            double weightSum = buffer[tail];
            for (int i = 0; i < size; i++) {
                grad.set(i, buffer[i] / weightSum);
                int rowOffset = size + i * size;
                for (int j = 0; j < size; j++) {
                    hessian.set(i, j, buffer[rowOffset + j] / weightSum);
                }
            }
            info[0] = weightSum;
            info[1] = buffer[tail + 1] / weightSum;
        }
    }

    /**
     * Solves the Newton system, applies the step, tracks the best coefficients and decides whether
     * to stop.
     *
     * <p>Stopping is signalled by writing {@code -1.0} into {@code dir.f1[0]}. Presumably
     * {@link IterTermination} reads this flag; confirm against that class.
     */
    public static class UpdateModel extends ComputeFunction {
        private static final Logger LOG = LoggerFactory.getLogger(UpdateModel.class);
        private static final long serialVersionUID = -4113558902964352141L;

        /** Right-hand side of the linear solve; allocated lazily on first use. */
        private DenseMatrix bMat;
        private final double epsilon;
        private final int maxIter;

        private UpdateModel(int maxIter, double epsilon) {
            this.maxIter = maxIter;
            this.epsilon = epsilon;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            Tuple2<DenseVector, double[]> grad =
                (Tuple2<DenseVector, double[]>) comContext.getObj(OptimVariable.dir);
            Tuple2<DenseVector, Double> curCoef =
                (Tuple2<DenseVector, Double>) comContext.getObj(OptimVariable.currentCoef);
            Tuple2<DenseVector, Double> minCoef =
                (Tuple2<DenseVector, Double>) comContext.getObj(OptimVariable.minCoef);
            DenseMatrix hessian = (DenseMatrix) comContext.getObj(OptimVariable.hessian);
            DenseVector step = grad.f0;
            int size = step.size();

            // Convergence is judged on the unscaled gradient.
            double gradNorm = step.normL2();

            double gradL1 = step.normL1();
            // An exactly-zero gradient means the Newton step is zero. Skipping the solve avoids
            // 1/0 = Infinity scaling, which would turn the Hessian/gradient (and then the
            // coefficients) into NaN. For a non-positive norm, step already holds the zero step.
            if (gradL1 > 0.0) {
                double factor = 1.0 / gradL1;
                // Scaling H and g by the same factor does not change the solution of H d = g.
                // It only improves conditioning, so skip it if the factor overflows.
                if (Double.isFinite(factor)) {
                    hessian.scaleEqual(factor);
                    step.scaleEqual(factor);
                }
                if (bMat == null) {
                    bMat = new DenseMatrix(size, 1);
                }
                for (int i = 0; i < size; i++) {
                    bMat.set(i, 0, step.get(i));
                }
                DenseMatrix solution = hessian.solveLS(bMat);
                for (int i = 0; i < size; i++) {
                    step.set(i, solution.get(i, 0));
                }
            }

            curCoef.f0.minusEqual(step);
            // NOTE(review): this loss was evaluated at the coefficients *before* the step just
            // applied, yet minCoef.f0 receives the post-step coefficients below.
            curCoef.f1 = grad.f1[1];
            if (curCoef.f1 < minCoef.f1) {
                minCoef.f1 = curCoef.f1;
                for (int i = 0; i < size; i++) {
                    minCoef.f0.set(i, curCoef.f0.get(i));
                }
            }
            filter(grad, curCoef, minCoef, gradNorm, comContext);
        }

        /**
         * Checks the stopping criteria and sets the stop flag ({@code tuple2.f1[0] = -1.0}) when
         * iteration should end.
         *
         * @param tuple2     direction tuple whose {@code f1[0]} doubles as the stop flag
         * @param tuple22    current coefficients and loss
         * @param tuple23    best coefficients and loss seen so far
         * @param d          L2 norm of the normalized gradient
         * @param comContext iteration context (step number, task id)
         */
        public void filter(Tuple2<DenseVector, double[]> tuple2, Tuple2<DenseVector, Double> tuple22,
                           Tuple2<DenseVector, Double> tuple23, double d, ComContext comContext) {
            double curLoss = tuple22.f1;
            double minLoss = tuple23.f1;
            if (curLoss < this.epsilon || d < this.epsilon) {
                printLog(" method converged at step : ", curLoss, minLoss, d, comContext);
                tuple2.f1[0] = -1.0;
            } else if (comContext.getStepNo() <= this.maxIter - 1) {
                printLog(" method continue at step : ", curLoss, minLoss, d, comContext);
            } else {
                printLog(" method stop at max step : ", curLoss, minLoss, d, comContext);
                tuple2.f1[0] = -1.0;
            }
        }

        /** Logs progress, and also prints it to stdout on task 0 when process info is enabled. */
        private void printLog(String str, double d, double d2, double d3, ComContext comContext) {
            if (comContext.getTaskId() == 0 && AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("Newton" + str + comContext.getStepNo() + " cur loss : " + d
                    + " min loss : " + d2 + " grad norm : " + d3);
            }
            // `str` already ends with " : ", so the step number placeholder follows it directly.
            LOG.info("Newton{}{}, cur loss: {}, min loss: {}, grad norm: {}",
                new Object[] {str, comContext.getStepNo(), d, d2, d3});
        }
    }

    public Newton(DataSet<OptimObjFunc> dataSet, DataSet<Tuple3<Double, Double, Vector>> dataSet2,
                  DataSet<Integer> dataSet3, Params params) {
        super(dataSet, dataSet2, dataSet3, params);
    }

    /**
     * Builds and runs the Newton iteration.
     *
     * @return the data set produced by {@link ParseRowModel} from the iteration output
     */
    @Override
    public DataSet<Tuple2<DenseVector, double[]>> optimize() {
        int maxIter = (Integer) this.params.get(LinearTrainParams.MAX_ITER);
        double epsilon = (Double) this.params.get(LinearTrainParams.EPSILON);
        checkInitCoef();
        return new IterativeComQueue()
            .initWithPartitionedData("trainData", this.trainData)
            .initWithBroadcastData(OptimVariable.model, this.coefVec)
            .initWithBroadcastData(OptimVariable.objFunc, this.objFuncSet)
            .add(new PreallocateCoefficient(OptimVariable.currentCoef))
            .add(new PreallocateCoefficient(OptimVariable.minCoef))
            .add(new PreallocateConvergenceInfo(OptimVariable.convergenceInfo, maxIter))
            // f1 of the direction tuple: [0] weight / stop flag, [1] normalized loss.
            .add(new PreallocateVector(OptimVariable.dir, new double[2]))
            .add(new PreallocateMatrix(OptimVariable.hessian, MAX_FEATURE_NUM))
            .add(new CalcGradientAndHessian())
            .add(new AllReduce(OptimVariable.gradHessAllReduce))
            .add(new GetGradeintAndHessian())
            .add(new UpdateModel(maxIter, epsilon))
            .setCompareCriterionOfNode0((CompareCriterionFunction) new IterTermination())
            .closeWith(new OutputModel())
            .setMaxIter(maxIter)
            .exec()
            .mapPartition(new ParseRowModel());
    }
}
