package com.alibaba.alink.operator.common.optim.activeSet;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.operator.common.optim.Optimizer;
import com.alibaba.alink.operator.common.optim.local.LocalSqp;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateMatrix;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateVector;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100;
import com.alibaba.alink.params.shared.linear.HasEpsilonDefaultAs0000001;
import com.alibaba.alink.params.shared.linear.HasL2;
import com.alibaba.alink.params.shared.linear.HasWithIntercept;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.linear.LinearConstraint;
import org.apache.commons.math3.optim.linear.LinearConstraintSet;
import org.apache.commons.math3.optim.linear.LinearObjectiveFunction;
import org.apache.commons.math3.optim.linear.Relationship;
import org.apache.commons.math3.optim.linear.SimplexSolver;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/activeSet/Sqp.class */
public class Sqp extends Optimizer {
    private static final int MAX_FEATURE_NUM = 3000;

    /**
     * Final step of the iteration. On task 0 only, it packs the optimized coefficient vector and the
     * last two recorded loss values into one JSON-encoded {@link Params} row.
     *
     * <p>All other tasks return {@code null}, meaning they contribute no output rows.
     */
    public static class BuildModel extends CompleteResultFunction {
        private static final long serialVersionUID = -7967444945852772659L;

        @Override
        public List<Row> calc(ComContext comContext) {
            if (comContext.getTaskId() != 0) {
                return null;
            }
            DenseVector coef = (DenseVector) comContext.getObj(ConstraintVariable.weight);

            // The loss curve is [lastLoss, loss]. An entry stays 0.0 when that value was never recorded
            // (for example, lastLoss is only written from step 2 on).
            double[] lossCurve = new double[2];
            if (comContext.containsObj(ConstraintVariable.lastLoss)) {
                lossCurve[0] = ((Double) comContext.getObj(ConstraintVariable.lastLoss)).doubleValue();
            }
            if (comContext.containsObj(ConstraintVariable.loss)) {
                lossCurve[1] = ((Double) comContext.getObj(ConstraintVariable.loss)).doubleValue();
            }

            // Pass typed key/value pairs directly. Casting a DenseVector or double[] to ParamInfo is
            // not valid Java, so the former casts were removed.
            Params params = new Params();
            params.set(ModelParamName.COEF, coef);
            params.set(ModelParamName.LOSS_CURVE, lossCurve);

            List<Row> rows = new ArrayList<>(1);
            rows.add(Row.of(params.toJson()));
            return rows;
        }
    }

    /**
     * Computes the relative loss decrease between the last two iterations. The result is stored
     * under {@code ConstraintVariable.convergence}, where {@code IterTermination} reads it.
     *
     * <p>The stored value is {@code (lastLoss - loss) / (|loss| * min(stepNo, 5))}. When the loss did
     * not change at all, the value is 0. This case previously produced {@code 0.0 / 0.0 = NaN} when
     * the loss was exactly 0, and a NaN never satisfies the termination test.
     *
     * <p>It also copies the equality/inequality right-hand sides held in the context
     * ({@code SqpVariable.ecmBias} / {@code SqpVariable.icmBias}) into the objective function.
     */
    public static class CalcConvergence extends ComputeFunction {
        private static final long serialVersionUID = 1936163292416518616L;

        /** Upper bound on the step count used to scale the loss decrease. */
        private static final int MAX_AVERAGING_STEPS = 5;

        @Override
        public void calc(ComContext comContext) {
            ConstraintObjFunc objFunc =
                (ConstraintObjFunc) ((List<?>) comContext.getObj(OptimVariable.objFunc)).get(0);
            objFunc.equalityItem = (DenseVector) comContext.getObj(SqpVariable.ecmBias);
            objFunc.inequalityItem = (DenseVector) comContext.getObj(SqpVariable.icmBias);

            int stepNo = comContext.getStepNo();
            if (stepNo == 1) {
                // There is no previous loss to compare against yet.
                return;
            }
            double loss = ((Double) comContext.getObj(ConstraintVariable.loss)).doubleValue();
            double lastLoss = ((Double) comContext.getObj(ConstraintVariable.lastLoss)).doubleValue();
            double decrease = lastLoss - loss;

            // An unchanged loss means converged. Without this guard, loss == 0 would give 0/0 = NaN.
            double convergence = decrease == 0.0
                ? 0.0
                : decrease / (Math.abs(loss) * Math.min(stepNo, MAX_AVERAGING_STEPS));
            comContext.putObj(ConstraintVariable.convergence, Double.valueOf(convergence));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/optim/activeSet/Sqp$CalcDir.class */
    public static class CalcDir extends ComputeFunction {
        private static final long serialVersionUID = 7694040433763461801L;
        private boolean hasIntercept;
        private double l2;

        public CalcDir(Params params) {
            this.hasIntercept = ((Boolean) params.get(HasWithIntercept.WITH_INTERCEPT)).booleanValue();
            this.l2 = ((Double) params.get(HasL2.L_2)).doubleValue();
        }

        @Override // com.alibaba.alink.common.comqueue.ComputeFunction
        public void calc(ComContext comContext) {
            ConstraintObjFunc constraintObjFunc = (ConstraintObjFunc) ((List) comContext.getObj(OptimVariable.objFunc)).get(0);
            Object obj = (Double) comContext.getObj(ConstraintVariable.loss);
            Tuple2 tuple2 = (Tuple2) comContext.getObj(OptimVariable.grad);
            DenseVector denseVector = (DenseVector) tuple2.f0;
            DenseMatrix denseMatrix = (DenseMatrix) comContext.getObj(OptimVariable.hessian);
            int intValue = ((Integer) ((List) comContext.getObj(ConstraintVariable.weightDim)).get(0)).intValue();
            int intValue2 = ((Integer) comContext.getObj(ConstraintVariable.newtonRetryTime)).intValue();
            double doubleValue = ((Double) comContext.getObj(ConstraintVariable.minL2Weight)).doubleValue();
            DenseVector denseVector2 = (DenseVector) comContext.getObj(ConstraintVariable.weight);
            DenseVector startDir = SqpPai.getStartDir(constraintObjFunc, denseVector2, (DenseVector) comContext.getObj(SqpVariable.icmBias), (DenseVector) comContext.getObj(SqpVariable.ecmBias));
            Tuple3<DenseVector, DenseVector, DenseMatrix> calcDir = SqpPai.calcDir(intValue2, intValue, constraintObjFunc, startDir, denseVector2, denseMatrix, denseVector, this.l2, doubleValue, this.hasIntercept, SqpPai.getActiveSet(constraintObjFunc.inequalityConstraint, constraintObjFunc.inequalityItem, startDir, intValue));
            Object obj2 = (DenseVector) calcDir.f0;
            tuple2.f0 = calcDir.f1;
            Object obj3 = (DenseMatrix) calcDir.f2;
            comContext.putObj(OptimVariable.grad, tuple2);
            comContext.putObj(OptimVariable.hessian, obj3);
            comContext.putObj(ConstraintVariable.loss, obj);
            comContext.putObj(OptimVariable.dir, obj2);
        }
    }

    /**
     * Evaluates {@code OptimObjFunc.calcHessianGradientLoss} on this task's {@code "trainData"} at the
     * current weight. The results are packed into the flat buffer
     * {@code OptimVariable.gradHessAllReduce} for the following AllReduce.
     *
     * <p>Buffer layout, with {@code n} = gradient size:
     * <ul>
     *   <li>{@code [0, n)}: gradient</li>
     *   <li>{@code [n, n + n*n)}: Hessian, row-major (entry (r, c) at {@code n + r*n + c})</li>
     *   <li>{@code n + n*n}: {@code f0} of the returned tuple (used downstream as a divisor;
     *       presumably the total sample weight)</li>
     *   <li>{@code n + n*n + 1}: {@code f1} of the returned tuple (presumably the loss)</li>
     * </ul>
     * The buffer is allocated on first use and reused afterwards; every slot is overwritten on each call.
     */
    public static class CalcGradAndHessian extends ComputeFunction {
        private static final long serialVersionUID = 4760392853024920737L;
        private OptimObjFunc objFunc;

        @Override
        public void calc(ComContext comContext) {
            @SuppressWarnings("unchecked")
            Iterable<Tuple3<Double, Double, Vector>> trainData =
                (Iterable<Tuple3<Double, Double, Vector>>) comContext.getObj("trainData");
            Tuple2<?, ?> gradPair = (Tuple2<?, ?>) comContext.getObj(OptimVariable.grad);
            DenseVector weight = (DenseVector) comContext.getObj(ConstraintVariable.weight);
            DenseMatrix hessian = (DenseMatrix) comContext.getObj(OptimVariable.hessian);
            DenseVector gradient = (DenseVector) gradPair.f0;
            int dim = gradient.size();

            if (objFunc == null) {
                objFunc = (OptimObjFunc) ((List<?>) comContext.getObj(OptimVariable.objFunc)).get(0);
            }
            Tuple2<Double, Double> sumAndLoss = objFunc.calcHessianGradientLoss(trainData, weight, hessian, gradient);

            int tailOffset = dim + dim * dim;
            double[] buffer = (double[]) comContext.getObj(OptimVariable.gradHessAllReduce);
            if (buffer == null) {
                buffer = new double[tailOffset + 2];
                comContext.putObj(OptimVariable.gradHessAllReduce, buffer);
            }

            for (int row = 0; row < dim; row++) {
                buffer[row] = gradient.get(row);
                int rowStart = dim + row * dim;
                for (int col = 0; col < dim; col++) {
                    buffer[rowStart + col] = hessian.get(row, col);
                }
            }
            buffer[tailOffset] = sumAndLoss.f0.doubleValue();
            buffer[tailOffset + 1] = sumAndLoss.f1.doubleValue();
        }
    }

    /**
     * Unpacks the AllReduce-summed {@code OptimVariable.gradHessAllReduce} buffer into the gradient
     * vector ({@code grad.f0}) and the Hessian matrix. Both are divided by the summed value at offset
     * {@code n + n*n}, where {@code n} is the gradient size.
     *
     * <p>Afterwards, {@code grad.f1[0]} holds that summed divisor. {@code grad.f1[1]} holds the summed
     * value at offset {@code n + n*n + 1} divided by the divisor (presumably the mean loss).
     */
    public static class GetGradientAndHessian extends ComputeFunction {
        private static final long serialVersionUID = 2724626183370161805L;

        @Override
        public void calc(ComContext comContext) {
            Tuple2<?, ?> gradPair = (Tuple2<?, ?>) comContext.getObj(OptimVariable.grad);
            DenseMatrix hessian = (DenseMatrix) comContext.getObj(OptimVariable.hessian);
            DenseVector gradient = (DenseVector) gradPair.f0;
            int dim = gradient.size();
            double[] buffer = (double[]) comContext.getObj(OptimVariable.gradHessAllReduce);

            int tailOffset = dim + dim * dim;
            double[] stats = (double[]) gradPair.f1;
            double divisor = buffer[tailOffset];
            stats[0] = divisor;

            for (int row = 0; row < dim; row++) {
                gradient.set(row, buffer[row] / divisor);
                int rowStart = dim + row * dim;
                for (int col = 0; col < dim; col++) {
                    hessian.set(row, col, buffer[rowStart + col] / divisor);
                }
            }
            stats[1] = buffer[tailOffset + 1] / divisor;
        }
    }

    /**
     * Chooses the step along {@code OptimVariable.dir} from the summed line-search losses in
     * {@code "lossAllReduce"}, using {@code LocalSqp.lineSearch}.
     *
     * <p>The weight vector passed to {@code lineSearch} is stored back under
     * {@code ConstraintVariable.weight}. It is presumably updated in place by {@code lineSearch};
     * confirm against {@code LocalSqp}. From step 2 on, the previous loss is kept as {@code lastLoss}
     * before the new loss is recorded.
     */
    public static class GetMinCoef extends ComputeFunction {
        private static final long serialVersionUID = 239058213400494835L;

        @Override
        public void calc(ComContext comContext) {
            double[] searchLosses = (double[]) comContext.getObj("lossAllReduce");
            Tuple2<?, ?> gradPair = (Tuple2<?, ?>) comContext.getObj(OptimVariable.grad);
            DenseVector direction = (DenseVector) comContext.getObj(OptimVariable.dir);
            DenseVector weight = (DenseVector) comContext.getObj(ConstraintVariable.weight);

            DenseVector gradient = (DenseVector) gradPair.f0;
            double newLoss = LocalSqp.lineSearch(searchLosses, weight, gradient, direction).doubleValue();
            comContext.putObj(ConstraintVariable.weight, weight);

            boolean hasPreviousLoss = comContext.getStepNo() != 1;
            if (hasPreviousLoss) {
                Object previousLoss = comContext.getObj(ConstraintVariable.loss);
                comContext.putObj(ConstraintVariable.lastLoss, previousLoss);
            }
            comContext.putObj(ConstraintVariable.loss, Double.valueOf(newLoss));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/optim/activeSet/Sqp$InitialCoef.class */
    public static class InitialCoef extends RichMapFunction<Integer, DenseVector> {
        private static final long serialVersionUID = -1725328800337420019L;
        private double[][] equalMatrix;
        private double[] equalItem;
        private double[][] inequalMatrix;
        private double[] inequalItem;

        public void open(Configuration configuration) throws Exception {
            ConstraintObjFunc constraintObjFunc = (ConstraintObjFunc) getRuntimeContext().getBroadcastVariable(OptimVariable.objFunc).get(0);
            this.inequalMatrix = constraintObjFunc.inequalityConstraint.getArrayCopy2D();
            this.inequalItem = constraintObjFunc.inequalityItem.getData();
            this.equalMatrix = constraintObjFunc.equalityConstraint.getArrayCopy2D();
            this.equalItem = constraintObjFunc.equalityItem.getData();
        }

        public DenseVector map(Integer num) throws Exception {
            return new DenseVector(Sqp.phaseOne(this.equalMatrix, this.equalItem, this.inequalMatrix, this.inequalItem, num.intValue()));
        }
    }

    /**
     * Seeds the context on the first superstep only. It stores the initial coefficient vector (from
     * the broadcast {@code OptimVariable.model}), fixed solver settings, an initial loss, and the
     * objective's constraint right-hand sides.
     */
    public static class InitializeParams extends ComputeFunction {
        private static final long serialVersionUID = 3775769468090056172L;

        /** Value published under {@code ConstraintVariable.minL2Weight}. */
        private static final double MIN_L2_WEIGHT = 1.0E-8d;
        /** Value published under {@code ConstraintVariable.linearSearchTimes}. */
        private static final int LINEAR_SEARCH_TIMES = 40;
        /** Value published under {@code ConstraintVariable.newtonRetryTime}. */
        private static final int NEWTON_RETRY_TIME = 12;

        private InitializeParams() {
        }

        @Override
        public void calc(ComContext comContext) {
            if (comContext.getStepNo() != 1) {
                return;
            }
            DenseVector initialCoef = (DenseVector) ((List<?>) comContext.getObj(OptimVariable.model)).get(0);
            comContext.putObj(ConstraintVariable.weight, initialCoef);
            comContext.putObj(ConstraintVariable.minL2Weight, Double.valueOf(MIN_L2_WEIGHT));
            comContext.putObj(ConstraintVariable.linearSearchTimes, Integer.valueOf(LINEAR_SEARCH_TIMES));
            comContext.putObj(ConstraintVariable.newtonRetryTime, Integer.valueOf(NEWTON_RETRY_TIME));
            // NOTE(review): Criteria.INVALID_GAIN is a decision-tree constant. Its use here looks like a
            // decompiler substitution for a numeric literal (presumably 0.0); confirm the intended
            // initial loss.
            comContext.putObj(ConstraintVariable.loss, Double.valueOf(Criteria.INVALID_GAIN));

            ConstraintObjFunc objFunc =
                (ConstraintObjFunc) ((List<?>) comContext.getObj(OptimVariable.objFunc)).get(0);
            comContext.putObj(SqpVariable.icmBias, objFunc.inequalityItem);
            comContext.putObj(SqpVariable.ecmBias, objFunc.equalityItem);
        }
    }

    /**
     * Termination criterion evaluated on node 0. The iteration stops once the absolute value stored
     * under {@code ConstraintVariable.convergence} is at most {@code epsilon}.
     *
     * <p>It never stops at step 1, because no convergence value has been computed yet.
     */
    public static class IterTermination extends CompareCriterionFunction {
        private static final long serialVersionUID = -9142037869276356229L;
        private final double epsilon;

        IterTermination(double epsilon) {
            this.epsilon = epsilon;
        }

        @Override
        public boolean calc(ComContext comContext) {
            if (comContext.getStepNo() == 1) {
                return false;
            }
            double convergence = ((Double) comContext.getObj(ConstraintVariable.convergence)).doubleValue();
            return Math.abs(convergence) <= epsilon;
        }
    }

    /**
     * Evaluates this task's line-search losses along {@code OptimVariable.dir} with
     * {@code ConstraintObjFunc.calcLineSearch}. The result is stored under {@code "lossAllReduce"} for
     * the following AllReduce.
     *
     * <p>If the configured L2 weight equals {@code Criteria.INVALID_GAIN} (presumably 0.0), the field is
     * increased in place by {@code ConstraintVariable.minL2Weight} on first use. After that it no
     * longer matches and stays at the bumped value.
     */
    public static class LineSearch extends ComputeFunction {
        private static final long serialVersionUID = 2611682666208211053L;
        private ConstraintObjFunc objFunc;
        private double l2Weight;

        public LineSearch(double l2Weight) {
            this.l2Weight = l2Weight;
        }

        @Override
        public void calc(ComContext comContext) {
            objFunc = (ConstraintObjFunc) ((List<?>) comContext.getObj(OptimVariable.objFunc)).get(0);
            DenseVector direction = (DenseVector) comContext.getObj(OptimVariable.dir);
            int searchTimes = ((Integer) comContext.getObj(ConstraintVariable.linearSearchTimes)).intValue();
            double minL2Weight = ((Double) comContext.getObj(ConstraintVariable.minL2Weight)).doubleValue();

            boolean noL2 = l2Weight == Criteria.INVALID_GAIN;
            if (noL2) {
                l2Weight += minL2Weight;
            }

            Iterable trainData = (Iterable) comContext.getObj("trainData");
            DenseVector weight = (DenseVector) comContext.getObj(ConstraintVariable.weight);
            Object searchLosses = objFunc.calcLineSearch(trainData, weight, direction, searchTimes, l2Weight);
            comContext.putObj("lossAllReduce", searchLosses);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/optim/activeSet/Sqp$ParseRowModel.class */
    public static class ParseRowModel extends RichMapPartitionFunction<Row, Tuple2<DenseVector, double[]>> {
        private static final long serialVersionUID = 7590757264182157832L;

        public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<DenseVector, double[]>> collector) throws Exception {
            DenseVector denseVector = null;
            double[] dArr = null;
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                Iterator<Row> it = iterable.iterator();
                while (it.hasNext()) {
                    Params fromJson = Params.fromJson((String) it.next().getField(0));
                    denseVector = (DenseVector) fromJson.get(ModelParamName.COEF);
                    dArr = (double[]) fromJson.get(ModelParamName.LOSS_CURVE);
                }
                if (denseVector != null) {
                    collector.collect(Tuple2.of(denseVector, dArr));
                }
            }
        }
    }

    /**
     * Creates an SQP optimizer. All arguments are handed unchanged to the {@code Optimizer}
     * constructor.
     *
     * @param objFuncSet data set holding the objective; this class later casts it to
     *                   {@code ConstraintObjFunc}
     * @param trainData  training samples (presumably weight, label, features — confirm with caller)
     * @param coefDim    data set holding the coefficient dimension
     * @param params     optimizer parameters (max iterations, L2, epsilon, intercept flag)
     */
    public Sqp(DataSet<OptimObjFunc> objFuncSet,
               DataSet<Tuple3<Double, Double, Vector>> trainData,
               DataSet<Integer> coefDim,
               Params params) {
        super(objFuncSet, trainData, coefDim, params);
    }

    /**
     * Builds the distributed SQP iteration.
     *
     * <p>Before the loop, {@code this.coefVec} is set to the phase-one starting point computed by
     * {@link InitialCoef}. Each superstep then runs these steps in order:
     * <ol>
     *   <li>seed state (step 1 only)</li>
     *   <li>preallocate the gradient and the Hessian (sized by {@code MAX_FEATURE_NUM}, presumably the
     *       maximum supported dimension)</li>
     *   <li>compute local gradient/Hessian/loss and all-reduce them, then normalise</li>
     *   <li>compute the search direction</li>
     *   <li>run the local line search and all-reduce its losses, then pick the step</li>
     *   <li>compute convergence</li>
     * </ol>
     * The loop ends when {@link IterTermination} fires or after {@code MAX_ITER} supersteps.
     *
     * @return the output of {@link ParseRowModel}: at most one (coefficients, [lastLoss, loss]) tuple
     */
    @Override
    public DataSet<Tuple2<DenseVector, double[]>> optimize() {
        int maxIter = ((Integer) this.params.get(HasMaxIterDefaultAs100.MAX_ITER)).intValue();

        this.coefVec = this.coefDim
            .map(new InitialCoef())
            .withBroadcastSet(this.objFuncSet, OptimVariable.objFunc);

        double l2 = ((Double) this.params.get(HasL2.L_2)).doubleValue();
        double epsilon = ((Double) this.params.get(HasEpsilonDefaultAs0000001.EPSILON)).doubleValue();

        return new IterativeComQueue()
            .initWithPartitionedData("trainData", this.trainData)
            .initWithBroadcastData(OptimVariable.model, this.coefVec)
            .initWithBroadcastData(OptimVariable.objFunc, this.objFuncSet)
            .initWithBroadcastData(ConstraintVariable.weightDim, this.coefDim)
            .add(new InitializeParams())
            .add(new PreallocateVector(OptimVariable.grad, new double[2]))
            .add(new PreallocateMatrix(OptimVariable.hessian, MAX_FEATURE_NUM))
            .add(new CalcGradAndHessian())
            .add(new AllReduce(OptimVariable.gradHessAllReduce))
            .add(new GetGradientAndHessian())
            .add(new CalcDir(this.params))
            .add(new LineSearch(l2))
            .add(new AllReduce("lossAllReduce"))
            .add(new GetMinCoef())
            .add(new CalcConvergence())
            .setCompareCriterionOfNode0((CompareCriterionFunction) new IterTermination(epsilon))
            .setMaxIter(maxIter)
            .closeWith(new BuildModel())
            .exec()
            .mapPartition(new ParseRowModel());
    }

    /**
     * Phase-one linear program that computes a starting coefficient vector for the SQP iteration.
     *
     * <p>The unknowns are the {@code dim} coefficients {@code x} followed by one slack variable per
     * constraint. The LP minimises the sum of the slacks subject to:
     * <ul>
     *   <li>{@code A_eq x + s_eq = b_eq}</li>
     *   <li>{@code A_in x + s_in >= b_in}</li>
     *   <li>{@code s >= Criteria.INVALID_GAIN} (presumably 0.0)</li>
     * </ul>
     * No {@code NonNegativeConstraint} is passed, so per the commons-math default, {@code x} is
     * unrestricted in sign.
     *
     * <p>NOTE(review): because each equality slack enters with a {@code +} sign and is bounded below,
     * every equality row is effectively relaxed to {@code A_eq x <= b_eq}.
     *
     * <p>If there are no constraints at all, a vector of {@code dim} copies of {@code 1e-4} is returned.
     *
     * @param equalMatrix   equality rows; only the first {@code dim} entries of each row are used
     * @param equalItem     equality right-hand sides; {@code null} means no equality constraints
     * @param inequalMatrix inequality rows; only the first {@code dim} entries of each row are used
     * @param inequalItem   inequality right-hand sides; {@code null} means no inequality constraints
     * @param dim           number of coefficients
     * @return the first {@code dim} components of the LP optimum
     */
    public static double[] phaseOne(double[][] equalMatrix, double[] equalItem, double[][] inequalMatrix,
                                    double[] inequalItem, int dim) {
        int numEq = equalItem == null ? 0 : equalItem.length;
        int numIneq = inequalItem == null ? 0 : inequalItem.length;
        int numSlack = numEq + numIneq;
        if (numSlack == 0) {
            double[] start = new double[dim];
            Arrays.fill(start, 1.0E-4d);
            return start;
        }
        int numVars = dim + numSlack;

        // The objective sums the slack variables only.
        double[] objectiveCoefs = new double[numVars];
        Arrays.fill(objectiveCoefs, dim, numVars, 1.0d);
        LinearObjectiveFunction objective = new LinearObjectiveFunction(objectiveCoefs, Criteria.INVALID_GAIN);

        List<LinearConstraint> constraints = new ArrayList<>();
        for (int r = 0; r < numEq; r++) {
            double[] coefs = rowWithSlack(equalMatrix[r], dim, numVars, dim + r);
            constraints.add(new LinearConstraint(coefs, Relationship.EQ, equalItem[r]));
        }
        for (int r = 0; r < numIneq; r++) {
            double[] coefs = rowWithSlack(inequalMatrix[r], dim, numVars, dim + numEq + r);
            constraints.add(new LinearConstraint(coefs, Relationship.GEQ, inequalItem[r]));
        }
        // Lower bound on every slack variable.
        for (int col = dim; col < numVars; col++) {
            double[] unit = new double[numVars];
            unit[col] = 1.0d;
            constraints.add(new LinearConstraint(unit, Relationship.GEQ, Criteria.INVALID_GAIN));
        }

        PointValuePair optimum = new SimplexSolver().optimize(
            objective, new LinearConstraintSet(constraints), GoalType.MINIMIZE);
        return Arrays.copyOf(optimum.getPoint(), dim);
    }

    /**
     * Builds one LP row of width {@code width}. It copies the first {@code dim} entries of
     * {@code coefs} and puts a 1 in column {@code slackCol}.
     */
    private static double[] rowWithSlack(double[] coefs, int dim, int width, int slackCol) {
        double[] row = new double[width];
        System.arraycopy(coefs, 0, row, 0, dim);
        row[slackCol] = 1.0d;
        return row;
    }
}
