package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.linear.unarylossfunc.UnaryLossFunc;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/linear/UnaryLossObjFunc.class */
/**
 * Objective function whose per-sample loss and derivatives are delegated to a {@link UnaryLossFunc}.
 *
 * <p>Every sample is a {@code Tuple3<Double, Double, Vector>}. The loss function is always evaluated at
 * {@code eta = dot(sample.f2, coefficients)} with {@code sample.f1} as its second argument, and
 * {@code sample.f0} is used as a multiplicative factor on the gradient, Hessian and line-search
 * contributions (presumably f0 = sample weight, f1 = label, f2 = feature vector; confirm against callers).
 */
public class UnaryLossObjFunc extends OptimObjFunc {
    private static final long serialVersionUID = 1178693053439209380L;

    /** Loss function supplying loss, first derivative and second derivative values. */
    private final UnaryLossFunc unaryLossFunc;

    /**
     * Creates the objective function.
     *
     * @param unaryLossFunc loss function that all computations are delegated to
     * @param params        parameters handed to the {@link OptimObjFunc} constructor
     */
    public UnaryLossObjFunc(UnaryLossFunc unaryLossFunc, Params params) {
        super(params);
        this.unaryLossFunc = unaryLossFunc;
    }

    /**
     * Reports whether a second derivative is available.
     *
     * @return always {@code true}
     */
    @Override
    public boolean hasSecondDerivative() {
        return true;
    }

    /**
     * Evaluates the loss of one sample. The factor {@code tuple3.f0} is not applied here.
     *
     * @param tuple3      sample; {@code f1} and {@code f2} are read
     * @param denseVector coefficient vector dotted with {@code tuple3.f2}
     * @return {@code unaryLossFunc.loss(eta, tuple3.f1)}
     */
    @Override
    public double calcLoss(Tuple3<Double, Double, Vector> tuple3, DenseVector denseVector) {
        double eta = getEta(tuple3, denseVector);
        double target = tuple3.f1;
        return unaryLossFunc.loss(eta, target);
    }

    /**
     * Adds one sample's contribution to the gradient accumulator in place:
     * {@code denseVector2 += (f0 * derivative(eta, f1)) * f2}.
     *
     * @param tuple3       sample; all three fields are read
     * @param denseVector  coefficient vector dotted with {@code tuple3.f2}
     * @param denseVector2 gradient accumulator that is modified
     */
    @Override
    public void updateGradient(Tuple3<Double, Double, Vector> tuple3, DenseVector denseVector, DenseVector denseVector2) {
        double factor = tuple3.f0;
        double eta = getEta(tuple3, denseVector);
        double slope = unaryLossFunc.derivative(eta, tuple3.f1);
        denseVector2.plusScaleEqual(tuple3.f2, factor * slope);
    }

    /**
     * Adds one sample's contribution to the Hessian accumulator in place. For every pair of entries
     * {@code (r, c)} of the sample vector, {@code x[r] * x[c] * secondDerivative(eta, f1) * f0} is added
     * to the matrix data element at offset {@code r + c * numRows}.
     *
     * @param tuple3      sample; all three fields are read
     * @param denseVector coefficient vector dotted with {@code tuple3.f2}
     * @param denseMatrix Hessian accumulator whose backing array is modified
     * @throws AkUnsupportedOperationException if {@code tuple3.f2} is neither a {@link DenseVector}
     *                                         nor a {@link SparseVector}
     */
    @Override
    public void updateHessian(Tuple3<Double, Double, Vector> tuple3, DenseVector denseVector, DenseMatrix denseMatrix) {
        Vector sampleVector = tuple3.f2;
        double eta = getEta(tuple3, denseVector);
        // Second derivative first, then the f0 factor: same operand order as the scalar product needs.
        double scale = unaryLossFunc.secondDerivative(eta, tuple3.f1) * tuple3.f0;

        if (sampleVector instanceof DenseVector) {
            int colCount = denseMatrix.numCols();
            int rowCount = denseMatrix.numRows();
            double[] hessian = denseMatrix.getData();
            double[] x = ((DenseVector) sampleVector).getData();
            // Walk the backing array column by column; colStart + row enumerates 0, 1, 2, ... in order.
            for (int col = 0; col < colCount; col++) {
                int colStart = col * rowCount;
                for (int row = 0; row < rowCount; row++) {
                    hessian[colStart + row] += x[row] * x[col] * scale;
                }
            }
        } else if (sampleVector instanceof SparseVector) {
            double[] hessian = denseMatrix.getData();
            int rowCount = denseMatrix.numRows();
            SparseVector sparse = (SparseVector) sampleVector;
            int[] positions = sparse.getIndices();
            double[] entries = sparse.getValues();
            // Only pairs of stored entries can contribute a non-zero product.
            for (int a = 0; a < entries.length; a++) {
                for (int b = 0; b < entries.length; b++) {
                    int offset = positions[a] + (positions[b] * rowCount);
                    hessian[offset] += entries[a] * entries[b] * scale;
                }
            }
        } else {
            // NOTE(review): the message mentions "sparse", but this branch is only reached for a vector
            // that is neither dense nor sparse. Text kept unchanged to preserve runtime behavior.
            throw new AkUnsupportedOperationException("not support sparse Hessian matrix computing.");
        }
    }

    /**
     * Computes the linear predictor of a sample.
     *
     * @param tuple3      sample whose {@code f2} vector is used
     * @param denseVector vector to dot with {@code tuple3.f2}
     * @return {@code MatVecOp.dot(tuple3.f2, denseVector)}
     */
    private double getEta(Tuple3<Double, Double, Vector> tuple3, DenseVector denseVector) {
        return MatVecOp.dot(tuple3.f2, denseVector);
    }

    /**
     * Accumulates, over all samples, the loss at {@code i + 1} candidate points. Entry {@code k} of the
     * result is the sum of {@code f0 * loss(eta - k * stepEta, f1)}, where
     * {@code eta = dot(f2, denseVector)} and {@code stepEta = dot(f2, denseVector2) * d}.
     *
     * @param iterable     samples to accumulate over
     * @param denseVector  vector used for the base linear predictor (presumably the current coefficients)
     * @param denseVector2 vector used for the per-step change (presumably the search direction)
     * @param d            multiplier applied to the per-step change
     * @param i            number of steps; the result has {@code i + 1} entries
     * @return accumulated values, one per candidate point
     */
    @Override
    public double[] calcSearchValues(Iterable<Tuple3<Double, Double, Vector>> iterable, DenseVector denseVector, DenseVector denseVector2, double d, int i) {
        double[] totals = new double[i + 1];
        for (Tuple3<Double, Double, Vector> sample : iterable) {
            double factor = sample.f0;
            double baseEta = getEta(sample, denseVector);
            double stepEta = getEta(sample, denseVector2) * d;
            for (int k = 0; k < i + 1; k++) {
                // sample.f1 is unboxed inside the loop on purpose, exactly where the original read it.
                totals[k] += factor * unaryLossFunc.loss(baseEta - (k * stepEta), sample.f1);
            }
        }
        return totals;
    }
}
