package com.alibaba.alink.operator.common.optim;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.CalcGradient;
import com.alibaba.alink.operator.common.optim.subfunc.CalcLosses;
import com.alibaba.alink.operator.common.optim.subfunc.IterTermination;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.optim.subfunc.OutputModel;
import com.alibaba.alink.operator.common.optim.subfunc.ParseRowModel;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateCoefficient;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateConvergenceInfo;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateSkyk;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateVector;
import com.alibaba.alink.operator.common.optim.subfunc.UpdateModel;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import com.alibaba.alink.params.shared.optim.HasNumSearchStepDefaultAs4;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/Lbfgs.class */
/**
 * Limited-memory BFGS optimizer built on an {@link IterativeComQueue}.
 *
 * <p>Each iteration computes and all-reduces the gradient, turns it into a search direction with the
 * L-BFGS two-loop recursion ({@link CalDirection}), evaluates losses along that direction, and
 * updates the model. Iteration stops on {@link IterTermination} or after {@code MAX_ITER} steps.
 */
public class Lbfgs extends Optimizer {

    /**
     * Number of (s, y) correction pairs kept in the L-BFGS history.
     *
     * <p>This value is shared by {@link PreallocateSkyk} and {@link CalDirection}. The direction
     * computation indexes the preallocated history arrays modulo this value, so the two must agree.
     */
    private static final int NUM_CORRECTIONS = 10;

    /**
     * Computes the L-BFGS search direction for the current iteration.
     *
     * <p>Per step this function:
     * <ol>
     *   <li>normalizes the all-reduced gradient buffer into {@code grad.f0};</li>
     *   <li>records {@code y = g_k - g_{k-1}} in the circular history slot {@code (k - 1) % m};</li>
     *   <li>overwrites {@code dir.f0} with {@code H_k * g_k} using the two-loop recursion over the
     *       most recent {@code min(k, m)} correction pairs.</li>
     * </ol>
     *
     * <p>No scaling is applied between the two loops, so the initial inverse-Hessian approximation
     * is the identity. The result in {@code dir.f0} is not negated here. NOTE(review): the descent
     * sign is presumably applied by the line search / UpdateModel; confirm.
     */
    public static class CalDirection extends ComputeFunction {
        private static final long serialVersionUID = -4061612963118027380L;

        /** Gradient from the previous step; used to form the y-vectors. */
        private transient DenseVector oldGradient;

        /** Scratch coefficients shared between the two loops of the recursion. */
        private transient double[] alpha;

        /** History length (number of correction pairs). */
        private final int m;

        /**
         * @param numCorrections number of correction pairs to use; must be positive
         * @throws IllegalArgumentException if {@code numCorrections <= 0}
         */
        private CalDirection(int numCorrections) {
            if (numCorrections <= 0) {
                throw new IllegalArgumentException("numCorrections must be positive: " + numCorrections);
            }
            this.m = numCorrections;
        }

        @Override
        @SuppressWarnings("unchecked") // ComContext stores untyped objects; types fixed by the preallocators.
        public void calc(ComContext context) {
            Tuple2<DenseVector, double[]> grad = (Tuple2<DenseVector, double[]>) context.getObj(OptimVariable.grad);
            Tuple2<DenseVector, double[]> dir = (Tuple2<DenseVector, double[]>) context.getObj(OptimVariable.dir);
            Tuple2<DenseVector[], DenseVector[]> sKyK =
                (Tuple2<DenseVector[], DenseVector[]>) context.getObj(OptimVariable.sKyK);

            DenseVector gradient = grad.f0;
            DenseVector direction = dir.f0;
            int size = gradient.size();
            double[] gradAllReduce = (double[]) context.getObj(OptimVariable.gradAllReduce);
            if (this.oldGradient == null) {
                this.oldGradient = new DenseVector(size);
            }
            // The s-history is kept with the static type Vector, matching the overloads used originally.
            Vector[] sK = sKyK.f0;
            DenseVector[] yK = sKyK.f1;

            // The all-reduced buffer holds the summed gradient followed by one trailing entry that is
            // used as the divisor (presumably the total sample weight — TODO confirm in CalcGradient).
            double normalizer = gradAllReduce[size];
            for (int i = 0; i < size; i++) {
                gradient.set(i, gradAllReduce[i] / normalizer);
            }
            dir.f1[0] = normalizer;

            int k = context.getStepNo() - 1;
            if (k != 0) {
                // y_{k-1} = g_k - g_{k-1}. The matching s-vector is written elsewhere
                // (NOTE(review): presumably by UpdateModel).
                DenseVector y = yK[(k - 1) % this.m];
                y.setEqual(gradient);
                y.minusEqual(this.oldGradient);
            }
            this.oldGradient.setEqual(gradient);

            // q <- g_k; both loops below update it in place.
            direction.setEqual(gradient);

            // Only the most recent min(k, m) pairs are valid; pair (i + delta) lives in slot
            // (i + delta) % m of the circular history.
            int delta = k > this.m ? k - this.m : 0;
            int numPairs = Math.min(k, this.m);
            if (this.alpha == null) {
                this.alpha = new double[this.m];
            }

            // Pairs whose curvature s.y does not exceed this threshold are skipped in both loops.
            // NOTE(review): Criteria.INVALID_GAIN is a tree-module constant, likely a decompiler
            // substitution for a literal threshold; confirm its value.
            for (int i = numPairs - 1; i >= 0; i--) {
                int j = (i + delta) % this.m;
                double sy = sK[j].dot(yK[j]);
                if (Math.abs(sy) > Criteria.INVALID_GAIN) {
                    this.alpha[i] = (1.0d / sy) * sK[j].dot(direction);
                    direction.plusScaleEqual(yK[j], -this.alpha[i]);
                }
            }
            for (int i = 0; i < numPairs; i++) {
                int j = (i + delta) % this.m;
                double sy = sK[j].dot(yK[j]);
                if (Math.abs(sy) > Criteria.INVALID_GAIN) {
                    double beta = (1.0d / sy) * yK[j].dot(direction);
                    direction.plusScaleEqual(sK[j], this.alpha[i] - beta);
                }
            }
        }
    }

    /**
     * @param objFunc objective-function data set, forwarded to {@link Optimizer}
     * @param trainData training samples, forwarded to {@link Optimizer}
     * @param coefDim integer data set forwarded to {@link Optimizer} (presumably the coefficient
     *     dimension — TODO confirm)
     * @param params optimizer parameters; {@code MAX_ITER} and {@code NUM_SEARCH_STEP} are read here
     */
    public Lbfgs(DataSet<OptimObjFunc> objFunc, DataSet<Tuple3<Double, Double, Vector>> trainData,
                 DataSet<Integer> coefDim, Params params) {
        super(objFunc, trainData, coefDim, params);
    }

    /**
     * Builds and runs the L-BFGS iteration.
     *
     * @return the data set produced by {@link ParseRowModel} from the final model
     */
    @Override
    public DataSet<Tuple2<DenseVector, double[]>> optimize() {
        int maxIter = ((Integer) this.params.get(LinearTrainParams.MAX_ITER)).intValue();
        int numSearchStep = ((Integer) this.params.get(HasNumSearchStepDefaultAs4.NUM_SEARCH_STEP)).intValue();
        checkInitCoef();
        return new IterativeComQueue()
            .initWithPartitionedData("trainData", this.trainData)
            .initWithBroadcastData(OptimVariable.model, this.coefVec)
            .initWithBroadcastData(OptimVariable.objFunc, this.objFuncSet)
            .add(new PreallocateCoefficient(OptimVariable.currentCoef))
            .add(new PreallocateCoefficient(OptimVariable.minCoef))
            .add(new PreallocateConvergenceInfo(OptimVariable.convergenceInfo, maxIter))
            .add(new PreallocateVector(OptimVariable.dir, new double[] {Criteria.INVALID_GAIN, 0.1d}))
            .add(new PreallocateVector(OptimVariable.grad))
            .add(new PreallocateSkyk(NUM_CORRECTIONS))
            .add(new CalcGradient())
            .add(new AllReduce(OptimVariable.gradAllReduce))
            .add(new CalDirection(NUM_CORRECTIONS))
            .add(new CalcLosses(LinearTrainParams.OptimMethod.LBFGS, numSearchStep))
            .add(new AllReduce("lossAllReduce"))
            .add(new UpdateModel(this.params, OptimVariable.grad, LinearTrainParams.OptimMethod.LBFGS, numSearchStep))
            .setCompareCriterionOfNode0((CompareCriterionFunction) new IterTermination())
            .closeWith(new OutputModel())
            .setMaxIter(maxIter)
            .exec()
            .mapPartition(new ParseRowModel());
    }
}
