package com.alibaba.alink.operator.common.optim.barrierIcq;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.Optimizer;
import com.alibaba.alink.operator.common.optim.activeSet.ConstraintObjFunc;
import com.alibaba.alink.operator.common.optim.activeSet.ConstraintVariable;
import com.alibaba.alink.operator.common.optim.activeSet.Sqp;
import com.alibaba.alink.operator.common.optim.activeSet.SqpPai;
import com.alibaba.alink.operator.common.optim.activeSet.SqpUtil;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateMatrix;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateVector;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100;
import com.alibaba.alink.params.shared.linear.HasEpsilonDefaultAs0000001;
import com.alibaba.alink.params.shared.linear.HasL2;
import com.alibaba.alink.params.shared.linear.HasWithIntercept;
import java.util.List;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/**
 * Log-barrier (interior-point) optimizer for objectives with linear constraints.
 *
 * <p>Each inner iteration computes a Newton direction for the objective plus the barrier term
 * {@code -(1 / t) * sum_i log(b_i - a_i . w)}, where {@code a_i} / {@code b_i} are the rows of
 * {@code ConstraintObjFunc.inequalityConstraint} / entries of {@code inequalityItem}; the barrier is
 * finite only while every slack {@code b_i - a_i . w} is positive. Equality constraints
 * {@code A_eq w = b_eq} enter through a bordered Newton system. When the inner loop converges (or
 * reaches {@code maxIter}) the barrier parameter {@code t} is multiplied by 50, and the whole
 * iteration stops once the tracked barrier weight drops below {@code epsilon}.
 */
public class LogBarrier extends Optimizer {
    /** Dimension handed to {@link PreallocateMatrix} for the Hessian buffer. */
    private static final int MAX_FEATURE_NUM = 3000;

    /**
     * Computes the inner-loop convergence measure and advances the inner-iteration counter.
     *
     * <p>The measure is the relative loss decrease {@code (lastLoss - loss) / (|loss| * min(iter, 5))},
     * where {@code iter} counts inner iterations since the last barrier update. On the first inner
     * iteration there is no previous loss, so 100 is reported to force at least one more iteration.
     * A negative value (the loss went up) is below any positive epsilon and therefore ends the inner
     * loop in {@link IterTermination}.
     */
    public static class CalcConvergence extends ComputeFunction {
        private static final long serialVersionUID = 4453719204627742833L;

        /** Upper bound on the iteration count used to normalize the loss decrease. */
        private static final int MAX_AVERAGING_WINDOW = 5;

        @Override
        public void calc(ComContext comContext) {
            int localIter = (Integer) comContext.getObj(BarrierVariable.localIterTime);
            double convergence;
            if (localIter == 0) {
                convergence = 100.0;
            } else {
                double lastLoss = (Double) comContext.getObj(ConstraintVariable.lastLoss);
                double loss = (Double) comContext.getObj(ConstraintVariable.loss);
                convergence = (lastLoss - loss)
                    / (Math.abs(loss) * Math.min(localIter, MAX_AVERAGING_WINDOW));
            }
            comContext.putObj(BarrierVariable.localIterTime, localIter + 1);
            comContext.putObj(ConstraintVariable.convergence, convergence);
        }
    }

    /**
     * Seeds the iteration state on the first superstep only.
     *
     * <p>The barrier parameter {@code t} starts at the number of inequality constraints {@code m}, and
     * the barrier weight {@code divideT} at {@code 1 / m} (or 0 when there are no inequality
     * constraints). It also stores the Newton retry budget, the minimum ridge weight, the number of
     * line-search steps, the initial weights (first broadcast model) and the initial losses.
     */
    public static class InitializeParams extends ComputeFunction {
        private static final long serialVersionUID = 1857803292287152190L;

        @Override
        public void calc(ComContext comContext) {
            if (comContext.getStepNo() != 1) {
                return;
            }
            ConstraintObjFunc objFunc =
                (ConstraintObjFunc) ((List<?>) comContext.getObj(OptimVariable.objFunc)).get(0);
            double inequalityCount = objFunc.inequalityConstraint.numRows();
            // Guard against 1 / 0 when the problem has no inequality constraints.
            double barrierWeight = inequalityCount == 0.0 ? 0.0 : 1.0 / inequalityCount;

            comContext.putObj(BarrierVariable.t, inequalityCount);
            comContext.putObj(BarrierVariable.divideT, barrierWeight);
            comContext.putObj(BarrierVariable.localIterTime, 0);
            comContext.putObj(BarrierVariable.hessianNotConvergence, false);
            comContext.putObj(ConstraintVariable.newtonRetryTime, 12);
            comContext.putObj(ConstraintVariable.minL2Weight, 1.0E-8);
            comContext.putObj(ConstraintVariable.linearSearchTimes, 40);
            comContext.putObj(ConstraintVariable.weight,
                (DenseVector) ((List<?>) comContext.getObj(OptimVariable.model)).get(0));
            comContext.putObj(ConstraintVariable.loss, 0.0);
            comContext.putObj(ConstraintVariable.lastLoss, Double.MAX_VALUE);
        }
    }

    /**
     * Termination check evaluated on node 0.
     *
     * <p>While the inner loop is still improving ({@code convergence >= epsilon}) and has run fewer than
     * {@code maxIter} iterations, it returns {@code false}. Otherwise it resets the inner-iteration
     * counter and tightens the barrier ({@code t *= 50}, {@code divideT /= 50}). It then stops the whole
     * iteration only if the new {@code divideT} is below {@code epsilon}.
     */
    public static class IterTermination extends CompareCriterionFunction {
        private static final long serialVersionUID = -3313706116254321792L;

        /** Factor applied to {@code t} (and inversely to {@code divideT}) after each inner loop. */
        private static final double T_GROWTH = 50.0;

        private final int maxIter;
        private final double epsilon;

        IterTermination(int maxIter, double epsilon) {
            this.maxIter = maxIter;
            this.epsilon = epsilon;
        }

        @Override
        public boolean calc(ComContext comContext) {
            double convergence = (Double) comContext.getObj(ConstraintVariable.convergence);
            // Kept as a positive test so that a NaN convergence also ends the inner loop.
            boolean innerLoopRunning = convergence >= epsilon
                && ((Integer) comContext.getObj(BarrierVariable.localIterTime)) < maxIter;
            if (innerLoopRunning) {
                return false;
            }
            comContext.putObj(BarrierVariable.localIterTime, 0);
            double t = ((Double) comContext.getObj(BarrierVariable.t)) * T_GROWTH;
            double divideT = ((Double) comContext.getObj(BarrierVariable.divideT)) / T_GROWTH;
            comContext.putObj(BarrierVariable.t, t);
            comContext.putObj(BarrierVariable.divideT, divideT);
            return divideT < epsilon;
        }
    }

    /**
     * Adds the log-barrier terms to the reduced gradient and Hessian, then solves the bordered Newton
     * system for the search direction ({@code OptimVariable.dir}).
     *
     * <p>If solving throws, a ridge term ({@code l2 + minL2Weight} on non-intercept coefficients,
     * {@code minL2Weight} on the intercept) is added to loss, gradient and Hessian. The minimum ridge
     * weight is then multiplied by 10 and the solve is retried, up to {@code newtonRetryTime} times.
     * Finally the (possibly regularized) gradient/Hessian and the barrier-augmented loss are written
     * back.
     */
    public static class RunNewtonStep extends ComputeFunction {
        private static final long serialVersionUID = 4802057437164571355L;

        /** Replacement slack when a point lies exactly on a constraint boundary (avoids log(0), 1/0). */
        private static final double ZERO_SLACK_SUBSTITUTE = 1.0E-6;

        private final boolean hasIntercept;
        private final double l2;

        public RunNewtonStep(Params params) {
            this.hasIntercept = (Boolean) params.get(HasWithIntercept.WITH_INTERCEPT);
            this.l2 = (Double) params.get(HasL2.L_2);
        }

        @Override
        public void calc(ComContext comContext) {
            ConstraintObjFunc objFunc =
                (ConstraintObjFunc) ((List<?>) comContext.getObj(OptimVariable.objFunc)).get(0);
            int dim = (Integer) ((List<?>) comContext.getObj(ConstraintVariable.weightDim)).get(0);
            // The intercept (index 0, when present) is excluded from the l2 part of the ridge.
            int firstPenalized = hasIntercept ? 1 : 0;
            double minL2Weight = (Double) comContext.getObj(ConstraintVariable.minL2Weight);
            double loss = (Double) comContext.getObj(ConstraintVariable.loss);
            Tuple2<?, ?> gradTuple = (Tuple2<?, ?>) comContext.getObj(OptimVariable.grad);
            // Mutated in place, so gradTuple.f0 keeps pointing at the updated gradient.
            DenseVector grad = (DenseVector) gradTuple.f0;
            DenseMatrix hessian = (DenseMatrix) comContext.getObj(OptimVariable.hessian);
            DenseVector weight = (DenseVector) comContext.getObj(ConstraintVariable.weight);
            int retryTimes = (Integer) comContext.getObj(ConstraintVariable.newtonRetryTime);
            double t = (Double) comContext.getObj(BarrierVariable.t);
            // The barrier enters gradient, Hessian AND loss with the same weight 1 / t.
            double barrierWeight = 1.0 / t;
            int constraintCount = objFunc.equalityItem.size() + objFunc.inequalityItem.size();
            if (constraintCount != 0) {
                addBarrierGradAndHessian(objFunc, grad, weight, hessian, barrierWeight);
            }

            int equalityCount = objFunc.equalityItem.size();
            // NOTE(review): if every attempt fails, OptimVariable.dir keeps its previous value.
            for (int attempt = 0; attempt < retryTimes; attempt++) {
                try {
                    DenseVector rhs = new DenseVector(equalityCount + dim);
                    SqpPai.vecAddVec(grad, rhs, dim);
                    DenseMatrix kkt = new DenseMatrix(buildH(hessian, weight, rhs,
                        objFunc.equalityConstraint, objFunc.equalityItem));
                    // Rescaling both sides by the same factor leaves the solution unchanged and only
                    // improves conditioning; skip it when the rhs is zero (1 / 0 would poison with NaN).
                    double rhsNorm = rhs.normL1();
                    if (rhsNorm > 0.0) {
                        double scale = 1.0 / rhsNorm;
                        kkt.scaleEqual(scale);
                        rhs.scaleEqual(scale);
                    }
                    DenseVector dir = new DenseVector(dim);
                    SqpPai.vecAddVec(kkt.inverse().multiplies(rhs), dir, dim);
                    comContext.putObj(OptimVariable.dir, dir);
                    break;
                } catch (Exception ignored) {
                    // Solving failed (presumably a singular system in inverse()): regularize and retry
                    // with a ridge that grows tenfold per attempt.
                    double ridge = l2 + minL2Weight;
                    for (int j = firstPenalized; j < dim; j++) {
                        double wj = weight.get(j);
                        loss += 0.5 * ridge * Math.pow(wj, 2.0);
                        grad.add(j, ridge * wj);
                        hessian.add(j, j, ridge);
                    }
                    if (hasIntercept) {
                        double w0 = weight.get(0);
                        loss += 0.5 * minL2Weight * Math.pow(w0, 2.0);
                        grad.add(0, minL2Weight * w0);
                        hessian.add(0, 0, minL2Weight);
                    }
                    minL2Weight *= 10.0;
                }
            }
            comContext.putObj(OptimVariable.grad, gradTuple);
            comContext.putObj(OptimVariable.hessian, hessian);
            comContext.putObj(ConstraintVariable.loss,
                constrainedLoss(loss, weight, objFunc, barrierWeight, constraintCount));
        }

        /**
         * Assembles the bordered Newton matrix {@code [[H, A^T], [A, 0]]} for equality constraints
         * {@code A w = b}. It also writes the equality residuals {@code A_i . w - b_i} into
         * {@code rhs[dim + i]}.
         *
         * <p>The Hessian is copied into the top-left corner via {@code SqpUtil.fillMatrix}; it is
         * assumed to be {@code dim x dim} with {@code dim = weight.size()} (TODO confirm).
         */
        private static double[][] buildH(DenseMatrix hessian, DenseVector weight, DenseVector rhs,
                                         DenseMatrix equalityConstraint, DenseVector equalityItem) {
            int equalityCount = equalityItem.size();
            int dim = weight.size();
            int size = equalityCount + dim;
            double[][] kkt = new double[size][size];
            SqpUtil.fillMatrix(kkt, 0, 0, hessian.getArrayCopy2D());
            for (int i = 0; i < equalityCount; i++) {
                int row = dim + i;
                double residual = 0.0;
                for (int j = 0; j < dim; j++) {
                    double a = equalityConstraint.get(i, j);
                    kkt[row][j] = a;
                    kkt[j][row] = a;
                    residual += a * weight.get(j);
                }
                rhs.set(row, residual - equalityItem.get(i));
            }
            return kkt;
        }

        /**
         * Returns {@code loss - barrierWeight * sum_i log(slack_i)} over the inequality rows, or
         * {@code loss} unchanged when the problem has no constraints at all.
         *
         * <p>A negative slack (infeasible point) makes {@code Math.log} return NaN.
         */
        private static double constrainedLoss(double loss, DenseVector weight, ConstraintObjFunc objFunc,
                                              double barrierWeight, int constraintCount) {
            if (constraintCount == 0) {
                return loss;
            }
            int rows = objFunc.inequalityConstraint.numRows();
            for (int i = 0; i < rows; i++) {
                loss -= barrierWeight * Math.log(
                    inequalitySlack(objFunc.inequalityConstraint, objFunc.inequalityItem, weight, i));
            }
            return loss;
        }

        /**
         * Slack {@code item[row] - constraint[row] . weight} of one inequality row. An exact zero is
         * replaced by {@link #ZERO_SLACK_SUBSTITUTE} so callers never take {@code log(0)} or divide by zero.
         */
        private static double inequalitySlack(DenseMatrix constraint, DenseVector item, DenseVector weight,
                                              int row) {
            double[] w = weight.getData();
            double slack = item.get(row);
            int cols = constraint.numCols();
            for (int j = 0; j < cols; j++) {
                slack -= w[j] * constraint.get(row, j);
            }
            return slack == 0.0 ? ZERO_SLACK_SUBSTITUTE : slack;
        }

        /**
         * Adds the barrier gradient {@code barrierWeight * sum_i a_i / s_i} and Hessian
         * {@code barrierWeight * sum_i a_i a_i^T / s_i^2} (with {@code s_i} the slack of row i) in place.
         * Each slack is computed once and shared by both updates.
         */
        private static void addBarrierGradAndHessian(ConstraintObjFunc objFunc, DenseVector grad,
                                                     DenseVector weight, DenseMatrix hessian,
                                                     double barrierWeight) {
            DenseMatrix constraint = objFunc.inequalityConstraint;
            DenseVector item = objFunc.inequalityItem;
            int rows = constraint.numRows();
            int cols = constraint.numCols();
            for (int i = 0; i < rows; i++) {
                double slack = inequalitySlack(constraint, item, weight, i);
                double slackSquared = Math.pow(slack, 2.0);
                for (int j = 0; j < cols; j++) {
                    grad.add(j, (barrierWeight * constraint.get(i, j)) / slack);
                }
                for (int j = 0; j < cols; j++) {
                    double aij = constraint.get(i, j);
                    for (int k = 0; k < cols; k++) {
                        hessian.add(j, k, ((barrierWeight * aij) * constraint.get(i, k)) / slackSquared);
                    }
                }
            }
        }
    }

    public LogBarrier(DataSet<OptimObjFunc> objFunc, DataSet<Tuple3<Double, Double, Vector>> trainData,
                      DataSet<Integer> coefDim, Params params) {
        super(objFunc, trainData, coefDim, params);
    }

    /**
     * Builds the iterative barrier pipeline:
     * <ol>
     *   <li>initialize;</li>
     *   <li>compute and all-reduce the gradient and Hessian;</li>
     *   <li>take a Newton step;</li>
     *   <li>line-search and all-reduce the losses;</li>
     *   <li>pick the best coefficients and check convergence.</li>
     * </ol>
     * The queue finally builds and parses the model.
     */
    @Override
    public DataSet<Tuple2<DenseVector, double[]>> optimize() {
        this.coefVec = this.coefDim.map(new Sqp.InitialCoef())
            .withBroadcastSet(this.objFuncSet, OptimVariable.objFunc);
        double l2 = (Double) this.params.get(HasL2.L_2);
        int maxIter = (Integer) this.params.get(HasMaxIterDefaultAs100.MAX_ITER);
        double epsilon = (Double) this.params.get(HasEpsilonDefaultAs0000001.EPSILON);
        return new IterativeComQueue()
            .initWithPartitionedData("trainData", this.trainData)
            .initWithBroadcastData(OptimVariable.model, this.coefVec)
            .initWithBroadcastData(OptimVariable.objFunc, this.objFuncSet)
            .initWithBroadcastData(ConstraintVariable.weightDim, this.coefDim)
            .add(new InitializeParams())
            .add(new PreallocateVector(OptimVariable.grad, new double[2]))
            .add(new PreallocateMatrix(OptimVariable.hessian, MAX_FEATURE_NUM))
            .add(new Sqp.CalcGradAndHessian())
            .add(new AllReduce(OptimVariable.gradHessAllReduce))
            .add(new Sqp.GetGradientAndHessian())
            .add(new RunNewtonStep(this.params))
            .add(new Sqp.LineSearch(l2))
            .add(new AllReduce("lossAllReduce"))
            .add(new Sqp.GetMinCoef())
            .add(new CalcConvergence())
            .setCompareCriterionOfNode0((CompareCriterionFunction) new IterTermination(maxIter, epsilon))
            .closeWith(new Sqp.BuildModel())
            .exec()
            .mapPartition(new Sqp.ParseRowModel());
    }
}
