package com.alibaba.alink.operator.common.optim.subfunc;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/subfunc/CalcLosses.class */
/**
 * Comqueue step that evaluates the objective function along the current search direction.
 *
 * <p>Each {@link #calc} call reads the training data, search direction and current coefficient
 * from the {@link ComContext}. It asks the objective function for search values at
 * {@code numSearchStep} steps. It then stores the result in the context under
 * {@code "lossAllReduce"}. When the method is {@code OWLQN}, the constrained variant
 * {@link OptimObjFunc#constraintCalcSearchValues} is used; otherwise
 * {@link OptimObjFunc#calcSearchValues} is used.
 */
public class CalcLosses extends ComputeFunction {
    private static final long serialVersionUID = 4621851063717484179L;

    /** Objective function; fetched from the context on the first {@link #calc} call, then cached. */
    private OptimObjFunc objFunc;

    /** Optimization method; {@code OWLQN} selects the constrained search evaluation. */
    private final LinearTrainParams.OptimMethod method;

    /** Number of search steps handed to the objective function; also divides the direction's step entry. */
    private final int numSearchStep;

    /**
     * Creates the step.
     *
     * @param optimMethod   optimization method; {@code OWLQN} switches to the constrained evaluation
     * @param numSearchStep number of search steps to evaluate
     */
    public CalcLosses(LinearTrainParams.OptimMethod optimMethod, int numSearchStep) {
        this.method = optimMethod;
        this.numSearchStep = numSearchStep;
    }

    /**
     * Computes the search values for this iteration and publishes them in the context.
     *
     * <p>The output array is created (as a copy) on the first call only. Later calls overwrite
     * that same array in place, so whatever consumes {@code "lossAllReduce"} keeps seeing one
     * stable object.
     *
     * @param comContext iteration context holding training data, direction, coefficient and objective
     */
    @Override
    @SuppressWarnings("unchecked")
    public void calc(ComContext comContext) {
        List<Tuple3<Double, Double, Vector>> trainData =
            (List<Tuple3<Double, Double, Vector>>) comContext.getObj("trainData");
        Tuple2<DenseVector, double[]> direction =
            (Tuple2<DenseVector, double[]>) comContext.getObj(OptimVariable.dir);
        Tuple2<DenseVector, ?> coef = (Tuple2<DenseVector, ?>) comContext.getObj(OptimVariable.currentCoef);

        // Resolve the objective function once; only the first registered entry is used.
        if (this.objFunc == null) {
            List<OptimObjFunc> registered = (List<OptimObjFunc>) comContext.getObj(OptimVariable.objFunc);
            this.objFunc = registered.get(0);
        }

        // Entry [1] of the direction's array, split evenly over numSearchStep.
        // NOTE(review): presumably the increment between successive candidate points; confirm
        // against the OptimObjFunc implementations.
        double stepLength = direction.f1[1] / this.numSearchStep;

        double[] searchValues;
        if (this.method.equals(LinearTrainParams.OptimMethod.OWLQN)) {
            searchValues = this.objFunc.constraintCalcSearchValues(
                trainData, coef.f0, direction.f0, stepLength, this.numSearchStep);
        } else {
            searchValues = this.objFunc.calcSearchValues(
                trainData, coef.f0, direction.f0, stepLength, this.numSearchStep);
        }

        double[] buffer = (double[]) comContext.getObj("lossAllReduce");
        if (buffer != null) {
            // Reuse the array already registered in the context; only the first
            // searchValues.length entries are overwritten.
            System.arraycopy(searchValues, 0, buffer, 0, searchValues.length);
        } else {
            // First call: register a private copy so later writes do not alias the objective's result.
            comContext.putObj("lossAllReduce", searchValues.clone());
        }
    }
}
