package com.alibaba.alink.operator.common.optim;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.IterTermination;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.optim.subfunc.OutputModel;
import com.alibaba.alink.operator.common.optim.subfunc.ParseRowModel;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateCoefficient;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateConvergenceInfo;
import com.alibaba.alink.operator.common.optim.subfunc.PreallocateVector;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import com.alibaba.alink.params.shared.optim.SgdParams;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/Sgd.class */
/**
 * Mini-batch stochastic gradient descent optimizer.
 *
 * <p>Each iteration runs the following queue:
 * <ol>
 *   <li>{@link CalcSubGradient}: every worker samples a fraction of its local rows (with
 *       replacement) and writes its weighted gradient, weight and loss into a reduce buffer;</li>
 *   <li>{@link AllReduce} combines the buffers of all workers;</li>
 *   <li>{@link GetGradient} turns the combined buffer into a weight-averaged gradient;</li>
 *   <li>{@link UpdateSgdModel} moves the coefficients against that gradient and decides whether
 *       to stop.</li>
 * </ol>
 */
public class Sgd extends Optimizer {

    /**
     * Samples a mini-batch from the local partition and fills the
     * {@link OptimVariable#gradAllReduce} buffer with this worker's contribution.
     *
     * <p>Buffer layout (length {@code dim + 2}):
     * <ul>
     *   <li>{@code [0, dim)}: local gradient multiplied by the local weight;</li>
     *   <li>{@code [dim]}: local weight (the value returned by {@code calcGradient});</li>
     *   <li>{@code [dim + 1]}: the loss value returned by {@code calcObjValue}.</li>
     * </ul>
     */
    public static class CalcSubGradient extends ComputeFunction {
        private static final long serialVersionUID = -5611469215969052818L;
        private OptimObjFunc objFunc;
        /** Fraction of the local partition drawn into each mini-batch. */
        private final double fraction;
        private transient List<Tuple3<Double, Double, Vector>> data = null;
        private transient List<Tuple3<Double, Double, Vector>> miniBatchData = null;
        private Random random = null;
        private int batchSize;
        private int totalSize;

        /**
         * @param fraction fraction of the local training rows sampled per iteration
         */
        public CalcSubGradient(double fraction) {
            this.fraction = fraction;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            if (this.data == null) {
                // First call on this worker: bind the local partition and size the mini-batch once.
                this.data = (List<Tuple3<Double, Double, Vector>>) comContext.getObj("trainData");
                this.random = new Random();
                this.totalSize = this.data.size();
                this.batchSize = miniBatchSize(this.totalSize, this.fraction);
                this.miniBatchData = new ArrayList<>(this.batchSize);
            }
            resampleMiniBatch();

            if (this.objFunc == null) {
                this.objFunc = ((List<OptimObjFunc>) comContext.getObj(OptimVariable.objFunc)).get(0);
            }

            Tuple2<DenseVector, ?> coef = (Tuple2<DenseVector, ?>) comContext.getObj(OptimVariable.minCoef);
            Tuple2<DenseVector, double[]> dir =
                (Tuple2<DenseVector, double[]>) comContext.getObj(OptimVariable.dir);
            DenseVector coefVector = coef.f0;
            DenseVector grad = dir.f0;
            int size = coefVector.size();

            // Clear the gradient buffer before the objective function fills it.
            for (int i = 0; i < size; i++) {
                grad.set(i, 0.0);
            }
            double weight = this.objFunc.calcGradient(this.miniBatchData, coefVector, grad);
            double loss = (Double) this.objFunc.calcObjValue(this.miniBatchData, coefVector).f0;
            dir.f1[0] = weight;
            dir.f1[1] = loss;

            double[] buffer = (double[]) comContext.getObj(OptimVariable.gradAllReduce);
            if (buffer == null) {
                buffer = new double[size + 2];
                comContext.putObj(OptimVariable.gradAllReduce, buffer);
            }
            // Scale by the local weight so the combined buffer can be turned into a weighted mean.
            for (int i = 0; i < size; i++) {
                buffer[i] = grad.get(i) * weight;
            }
            buffer[size] = weight;
            buffer[size + 1] = loss;
        }

        /**
         * Computes the number of rows drawn per iteration. The result is
         * {@code floor(total * fraction)}, raised to 1 when the partition is non-empty and
         * {@code fraction} is positive. Without that floor, a small partition would sample
         * nothing on every iteration.
         */
        private static int miniBatchSize(int total, double fraction) {
            int size = (int) (total * fraction);
            return (size == 0 && total > 0 && fraction > 0.0) ? 1 : size;
        }

        /** Refills {@code miniBatchData} with {@code batchSize} rows drawn uniformly with replacement. */
        private void resampleMiniBatch() {
            this.miniBatchData.clear();
            for (int i = 0; i < this.batchSize; i++) {
                this.miniBatchData.add(this.data.get(this.random.nextInt(this.totalSize)));
            }
        }
    }

    /**
     * Converts the combined {@link OptimVariable#gradAllReduce} buffer into the weight-averaged
     * gradient stored in {@code dir.f0}. It also stores the combined weight in {@code dir.f1[0]}
     * and the combined loss in {@code dir.f1[1]}.
     */
    public static class GetGradient extends ComputeFunction {
        private static final long serialVersionUID = -9048669840419848014L;

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            Tuple2<DenseVector, double[]> dir =
                (Tuple2<DenseVector, double[]>) comContext.getObj(OptimVariable.dir);
            DenseVector grad = dir.f0;
            int size = grad.size();
            double[] reduced = (double[]) comContext.getObj(OptimVariable.gradAllReduce);
            double weightSum = reduced[size];
            for (int i = 0; i < size; i++) {
                // A zero total weight (e.g. every mini-batch empty) would give 0/0 = NaN here and
                // permanently poison the coefficients; treat it as a zero gradient instead.
                grad.set(i, weightSum != 0.0 ? reduced[i] / weightSum : 0.0);
            }
            dir.f1[0] = weightSum;
            dir.f1[1] = reduced[size + 1];
        }
    }

    /**
     * Applies one SGD step to the coefficients and checks the stopping conditions.
     *
     * <p>The step size is {@code learnRate / (||g||_inf + sqrt(step))}, where {@code g} is the
     * averaged gradient.
     */
    public static class UpdateSgdModel extends ComputeFunction {
        private static final Logger LOG = LoggerFactory.getLogger(UpdateSgdModel.class);
        private static final long serialVersionUID = 2064300908719020918L;
        private final double epsilon;
        private final int maxIter;
        private final double learnRate;
        private final LinearTrainParams.OptimMethod method;

        /**
         * @param maxIter   maximum number of iterations
         * @param epsilon   stop once the L2 norm of the averaged gradient falls below this value
         * @param learnRate base learning rate
         * @param method    optimizer name used in log/progress messages
         */
        public UpdateSgdModel(int maxIter, double epsilon, double learnRate,
                              LinearTrainParams.OptimMethod method) {
            this.maxIter = maxIter;
            this.epsilon = epsilon;
            this.method = method;
            this.learnRate = learnRate;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void calc(ComContext comContext) {
            Tuple2<DenseVector, double[]> dir =
                (Tuple2<DenseVector, double[]>) comContext.getObj(OptimVariable.dir);
            Tuple2<DenseVector, ?> coef = (Tuple2<DenseVector, ?>) comContext.getObj(OptimVariable.minCoef);
            // The step decays with sqrt(step) and shrinks when the gradient has large entries.
            double stepSize = this.learnRate / (dir.f0.normInf() + Math.sqrt(comContext.getStepNo()));
            coef.f0.plusScaleEqual((Vector) dir.f0, -stepSize);
            filter(dir, comContext, stepSize);
        }

        /**
         * Reports progress and marks the iteration as finished by writing {@code -1.0} to
         * {@code grad.f1[0]}. This happens when the gradient norm drops below {@code epsilon} or
         * {@code maxIter} steps have run. The flag is presumably consumed by the
         * {@code IterTermination} criterion; TODO confirm.
         *
         * @param grad         direction tuple whose {@code f1} holds {@code [weightSum, loss]}
         * @param comContext   iteration context; the gradient norm is read from its
         *                     {@link OptimVariable#dir} entry
         * @param learningRate step size used in this iteration (only printed)
         */
        @SuppressWarnings("unchecked")
        public void filter(Tuple2<DenseVector, double[]> grad, ComContext comContext, double learningRate) {
            Tuple2<DenseVector, double[]> dir =
                (Tuple2<DenseVector, double[]>) comContext.getObj(OptimVariable.dir);
            double gradNorm = dir.f0.normL2();
            int step = comContext.getStepNo();
            if (comContext.getTaskId() == 0 && AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(this.method.toString() + " method continue at step : " + step
                    + " cur loss : " + (grad.f1[1] / grad.f1[0]) + " grad norm : " + gradNorm
                    + " learning rate : " + learningRate);
            }
            if (gradNorm < this.epsilon) {
                LOG.info(this.method.toString() + " method converged at step : : {}, grad norm: {}", step, gradNorm);
                grad.f1[0] = -1.0d;
            } else if (step >= this.maxIter) {
                LOG.info(this.method.toString() + " method stop at max step : : {}, grad norm: {}", step, gradNorm);
                grad.f1[0] = -1.0d;
            }
        }
    }

    /**
     * @param objFuncSet objective function (broadcast to all workers)
     * @param trainData  training rows as {@code (weight?, label?, features)} tuples
     *                   (field meaning assumed; TODO confirm)
     * @param coefDim    coefficient dimension
     * @param params     optimizer parameters ({@link SgdParams})
     */
    public Sgd(DataSet<OptimObjFunc> objFuncSet, DataSet<Tuple3<Double, Double, Vector>> trainData,
               DataSet<Integer> coefDim, Params params) {
        super(objFuncSet, trainData, coefDim, params);
    }

    /**
     * Builds and runs the SGD iteration.
     *
     * @return the optimized coefficients together with convergence information
     */
    @Override
    public DataSet<Tuple2<DenseVector, double[]>> optimize() {
        final int maxIter = this.params.get(SgdParams.MAX_ITER);
        final double learningRate = this.params.get(SgdParams.LEARNING_RATE);
        final double fraction = this.params.get(SgdParams.MINI_BATCH_FRACTION);
        final double epsilon = this.params.get(SgdParams.EPSILON);
        checkInitCoef();
        return new IterativeComQueue()
            .initWithPartitionedData("trainData", this.trainData)
            .initWithBroadcastData(OptimVariable.model, this.coefVec)
            .initWithBroadcastData(OptimVariable.objFunc, this.objFuncSet)
            .add(new PreallocateCoefficient(OptimVariable.minCoef))
            .add(new PreallocateConvergenceInfo(OptimVariable.convergenceInfo, maxIter))
            .add(new PreallocateVector(OptimVariable.dir, new double[2]))
            .add(new CalcSubGradient(fraction))
            .add(new AllReduce(OptimVariable.gradAllReduce))
            .add(new GetGradient())
            .add(new UpdateSgdModel(maxIter, epsilon, learningRate, LinearTrainParams.OptimMethod.SGD))
            .setCompareCriterionOfNode0((CompareCriterionFunction) new IterTermination())
            .closeWith(new OutputModel())
            .setMaxIter(maxIter)
            .exec()
            .mapPartition(new ParseRowModel());
    }
}
