package com.alibaba.alink.operator.common.optim.subfunc;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/subfunc/CalcGradient.class */
/**
 * Compute step that evaluates the objective function's gradient on the training data held in the
 * {@link ComContext} and stages the result in a {@code double[]} buffer.
 *
 * <p>After {@link #calc(ComContext)} runs, the buffer stored under
 * {@code OptimVariable.gradAllReduce} has length {@code size + 1}, where {@code size} is the
 * dimension of the current coefficient vector:
 * <ul>
 *   <li>entries {@code [0, size)} hold the direction vector multiplied by the scalar returned from
 *       {@code OptimObjFunc.calcGradient};</li>
 *   <li>entry {@code size} holds that scalar itself.</li>
 * </ul>
 * NOTE(review): the buffer name suggests it is later summed across workers by an all-reduce —
 * confirm against the surrounding comqueue wiring.
 */
public class CalcGradient extends ComputeFunction {
    private static final long serialVersionUID = 5489606217395507860L;

    /** Objective function; looked up from the context on the first call and reused afterwards. */
    private OptimObjFunc objFunc;

    /**
     * Computes the local gradient contribution and writes it into the gradient buffer.
     *
     * @param comContext context that must supply {@code "trainData"},
     *     {@code OptimVariable.currentCoef}, {@code OptimVariable.dir} and (on the first call)
     *     {@code OptimVariable.objFunc}; the {@code OptimVariable.gradAllReduce} buffer is created
     *     in it when absent or wrongly sized
     */
    @Override // com.alibaba.alink.common.comqueue.ComputeFunction
    @SuppressWarnings("unchecked") // ComContext stores untyped objects; the casts mirror what is put there.
    public void calc(ComContext comContext) {
        Iterable<Tuple3<Double, Double, Vector>> trainData =
            (Iterable<Tuple3<Double, Double, Vector>>) comContext.getObj("trainData");
        DenseVector coef = ((Tuple2<DenseVector, ?>) comContext.getObj(OptimVariable.currentCoef)).f0;
        int size = coef.size();

        if (this.objFunc == null) {
            this.objFunc = ((List<OptimObjFunc>) comContext.getObj(OptimVariable.objFunc)).get(0);
        }

        // NOTE(review): presumably calcGradient writes the local gradient into `dir` and returns a
        // weight (e.g. sample weight sum) — confirm against OptimObjFunc.
        DenseVector dir = ((Tuple2<DenseVector, ?>) comContext.getObj(OptimVariable.dir)).f0;
        double weight = this.objFunc.calcGradient(trainData, coef, dir);

        double[] gradient = (double[]) comContext.getObj(OptimVariable.gradAllReduce);
        // Reallocate when missing or sized for a different dimension: a shorter buffer would
        // overflow below, and a longer one would carry stale trailing values.
        if (gradient == null || gradient.length != size + 1) {
            gradient = new double[size + 1];
            comContext.putObj(OptimVariable.gradAllReduce, gradient);
        }
        for (int i = 0; i < size; i++) {
            gradient[i] = dir.get(i) * weight;
        }
        // The scalar is appended last so consumers can recover it alongside the scaled vector.
        gradient[size] = weight;
    }
}
