package com.alibaba.alink.operator.common.feature.AutoCross;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.linear.unarylossfunc.LogLossFunc;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossObjFunc.class */
/**
 * Log-loss objective in which only a trailing slice of the coefficient vector is optimized.
 *
 * <p>Every method receives the full coefficient vector ({@code coefVector}). The gradient vector
 * (in {@link #updateGradient}) and the search direction (in {@link #calcSearchValues}) are
 * shorter. The leading {@code coefVector.size() - grad.size()} coefficients are treated as fixed:
 * they contribute to the margin {@code features . coefVector}, but they receive no gradient and
 * are not moved by the line search. Presumably the trailing slice holds the coefficients of the
 * candidate crossed features being evaluated (TODO confirm against the caller).
 *
 * <p>Each sample is a {@code Tuple3<weight, label, features>}. Only first-order optimization is
 * supported: {@link #hasSecondDerivative()} returns {@code false}.
 */
public class AutoCrossObjFunc extends OptimObjFunc {
    private static final long serialVersionUID = 4901712689789156832L;

    /** Loss applied to each sample's margin. */
    private final LogLossFunc lossFunc;

    /**
     * Creates the objective.
     *
     * @param params optimization parameters, forwarded unchanged to {@link OptimObjFunc}
     */
    public AutoCrossObjFunc(Params params) {
        super(params);
        this.lossFunc = new LogLossFunc();
    }

    /**
     * Computes the log loss of a single sample at the given coefficients.
     *
     * <p>The returned value is NOT multiplied by the sample weight. Presumably the caller in
     * {@link OptimObjFunc} applies the weight (TODO confirm).
     *
     * @param labelVector sample as {@code (weight, label, features)}
     * @param coefVector  full coefficient vector
     * @return unweighted log loss of the sample
     */
    @Override
    public double calcLoss(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector) {
        return lossFunc.loss(getEta(labelVector, coefVector), labelVector.f1);
    }

    /**
     * Adds one sample's weighted log-loss gradient to {@code updateGrad}, restricted to the
     * trailing coefficients being optimized.
     *
     * <p>The contribution is scaled by the sample weight ({@code labelVector.f0}). This keeps the
     * gradient consistent with the weighted loss that {@link #calcSearchValues} evaluates.
     *
     * @param labelVector sample as {@code (weight, label, features)}; features may be sparse or dense
     * @param coefVector  full coefficient vector (fixed prefix followed by optimized tail)
     * @param updateGrad  gradient accumulator covering only the optimized tail; updated in place
     * @throws IllegalArgumentException if the features are neither a {@link SparseVector} nor a
     *                                  {@link DenseVector}
     */
    @Override
    public void updateGradient(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector,
                               DenseVector updateGrad) {
        double weight = labelVector.f0;
        double scale = weight * lossFunc.derivative(getEta(labelVector, coefVector), labelVector.f1);
        // Number of leading coefficients that are held fixed and get no gradient.
        int fixedCoefSize = coefVector.size() - updateGrad.size();
        double[] gradData = updateGrad.getData();
        Vector features = labelVector.f2;

        if (features instanceof SparseVector) {
            SparseVector sparse = (SparseVector) features;
            int[] indices = sparse.getIndices();
            double[] values = sparse.getValues();
            for (int k = 0; k < indices.length; k++) {
                int index = indices[k];
                // Skip non-zeros that fall in the fixed prefix.
                if (index >= fixedCoefSize) {
                    gradData[index - fixedCoefSize] += scale * values[k];
                }
            }
        } else if (features instanceof DenseVector) {
            double[] values = ((DenseVector) features).getData();
            int end = Math.min(values.length, coefVector.size());
            for (int j = fixedCoefSize; j < end; j++) {
                gradData[j - fixedCoefSize] += scale * values[j];
            }
        } else {
            throw new IllegalArgumentException(
                "Unsupported feature vector type: " + features.getClass().getName());
        }
    }

    /**
     * Not supported: this objective provides no second-order information.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void updateHessian(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector,
                              DenseMatrix updateHessian) {
        throw new UnsupportedOperationException("do not support hessian.");
    }

    /**
     * Reports that no Hessian is available.
     *
     * @return {@code false}
     */
    @Override
    public boolean hasSecondDerivative() {
        return false;
    }

    /**
     * Evaluates the weighted total loss at {@code numStep + 1} evenly spaced points along a
     * search direction.
     *
     * <p>Entry {@code k} of the result is {@code sum(weight * loss(eta - k * beta * (x . dir)))}.
     * Here {@code eta = x . coefVector}, and {@code dir} is {@code dirVec} padded with zeros over
     * the fixed coefficient prefix. Entry 0 is therefore the loss at the current coefficients.
     *
     * @param labelVectors samples as {@code (weight, label, features)}
     * @param coefVector   full coefficient vector
     * @param dirVec       search direction over the optimized tail only
     * @param beta         step length multiplier applied to the direction
     * @param numStep      number of steps; the result has {@code numStep + 1} entries
     * @return weighted loss per step
     */
    @Override
    public double[] calcSearchValues(Iterable<Tuple3<Double, Double, Vector>> labelVectors,
                                     DenseVector coefVector, DenseVector dirVec, double beta, int numStep) {
        double[] losses = new double[numStep + 1];
        int fixedCoefSize = coefVector.size() - dirVec.size();
        // Lift the tail-only direction to full length so it can be dotted with full-length features.
        double[] paddedDir = new double[coefVector.size()];
        System.arraycopy(dirVec.getData(), 0, paddedDir, fixedCoefSize, dirVec.size());
        DenseVector fullDir = new DenseVector(paddedDir);

        for (Tuple3<Double, Double, Vector> labelVector : labelVectors) {
            double weight = labelVector.f0;
            double label = labelVector.f1;
            double eta = getEta(labelVector, coefVector);
            double etaStep = getEta(labelVector, fullDir) * beta;
            for (int k = 0; k <= numStep; k++) {
                losses[k] += weight * lossFunc.loss(eta - k * etaStep, label);
            }
        }
        return losses;
    }

    /** Returns the margin {@code features . coefVector} of a sample. */
    private double getEta(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector) {
        return MatVecOp.dot(labelVector.f2, coefVector);
    }
}
