package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/linear/AftRegObjFunc.class */
/**
 * Objective function for accelerated failure time (AFT) regression.
 *
 * <p>The last entry of the coefficient vector is treated as {@code log(sigma)}: every method
 * exponentiates it and divides the residual {@code f1 - <f2, coef>} by the result. The remaining
 * entries are the linear coefficients applied to the feature vector.
 *
 * <p>Each sample is a {@code Tuple3<Double, Double, Vector>}:
 * <ul>
 *   <li>{@code f0} - factor multiplying {@code (logSigma - epsilon)} in the loss
 *       (presumably the censor/event indicator; confirm against the caller);</li>
 *   <li>{@code f1} - regression target;</li>
 *   <li>{@code f2} - feature vector, sparse or dense.</li>
 * </ul>
 *
 * <p>Per-sample loss: {@code f0 * (logSigma - epsilon) + exp(epsilon)} with
 * {@code epsilon = (f1 - <f2, coef>) / exp(logSigma)}.
 */
public class AftRegObjFunc extends OptimObjFunc {
    private static final long serialVersionUID = -25113151208677581L;

    /** Zero used for accumulator resets, "regularization disabled" checks and sign tests. */
    private static final double ZERO = 0.0d;

    /**
     * Creates the objective function.
     *
     * @param params parameters forwarded to {@link OptimObjFunc}
     */
    public AftRegObjFunc(Params params) {
        super(params);
    }

    /**
     * Computes the dot product of a feature vector with the coefficient vector, ignoring any
     * feature position that lies outside the coefficient array.
     *
     * @param vector      feature vector; handled as sparse when it is a {@link SparseVector},
     *                    otherwise cast to {@link DenseVector}
     * @param denseVector coefficient vector
     * @return sum of products over the positions present in both vectors
     */
    public static double getDotProduct(Vector vector, DenseVector denseVector) {
        double[] coefs = denseVector.getData();
        double sum = 0.0d;
        if (vector instanceof SparseVector) {
            SparseVector sparse = (SparseVector) vector;
            int[] indices = sparse.getIndices();
            double[] values = sparse.getValues();
            for (int k = 0; k < indices.length; k++) {
                int idx = indices[k];
                // Skip entries whose index falls beyond the coefficient array.
                if (idx < coefs.length) {
                    sum += values[k] * coefs[idx];
                }
            }
        } else {
            double[] features = ((DenseVector) vector).getData();
            int limit = Math.min(features.length, coefs.length);
            for (int k = 0; k < limit; k++) {
                sum += features[k] * coefs[k];
            }
        }
        return sum;
    }

    /**
     * Computes the loss of one sample.
     *
     * @param tuple3      sample (see class documentation for the field layout)
     * @param denseVector coefficient vector whose last entry is {@code log(sigma)}
     * @return {@code f0 * (logSigma - epsilon) + exp(epsilon)}
     */
    @Override
    public double calcLoss(Tuple3<Double, Double, Vector> tuple3, DenseVector denseVector) {
        double logSigma = denseVector.get(denseVector.size() - 1);
        double epsilon = (tuple3.f1 - getDotProduct(tuple3.f2, denseVector)) / Math.exp(logSigma);
        return (tuple3.f0 * (logSigma - epsilon)) + Math.exp(epsilon);
    }

    /**
     * Adds the contribution of one sample to the gradient accumulator.
     *
     * @param tuple3       sample
     * @param denseVector  coefficient vector whose last entry is {@code log(sigma)}
     * @param denseVector2 gradient accumulator, updated in place
     */
    @Override
    public void updateGradient(Tuple3<Double, Double, Vector> tuple3, DenseVector denseVector, DenseVector denseVector2) {
        double sigma = Math.exp(denseVector.get(denseVector.size() - 1));
        double epsilon = (tuple3.f1 - getDotProduct(tuple3.f2, denseVector)) / sigma;
        double multiplier = tuple3.f0 - Math.exp(epsilon);
        double featureScale = multiplier / sigma;
        if (tuple3.f2 instanceof SparseVector) {
            SparseVector sparse = (SparseVector) tuple3.f2;
            int[] indices = sparse.getIndices();
            double[] values = sparse.getValues();
            for (int k = 0; k < indices.length; k++) {
                denseVector2.add(indices[k], values[k] * featureScale);
            }
        } else {
            double[] features = ((DenseVector) tuple3.f2).getData();
            for (int k = 0; k < features.length; k++) {
                denseVector2.add(k, features[k] * featureScale);
            }
        }
        // Derivative with respect to the last coefficient (log sigma).
        denseVector2.add(denseVector.size() - 1, tuple3.f0 + (multiplier * epsilon));
    }

    /**
     * Adds the contribution of one sample to the Hessian accumulator.
     *
     * <p>Only the feature-by-feature entries and the diagonal entry of the last coefficient are
     * updated; entries mixing a feature coefficient with the last coefficient are not touched.
     *
     * @param tuple3      sample
     * @param denseVector coefficient vector whose last entry is {@code log(sigma)}
     * @param denseMatrix Hessian accumulator, updated in place
     */
    @Override
    public void updateHessian(Tuple3<Double, Double, Vector> tuple3, DenseVector denseVector, DenseMatrix denseMatrix) {
        double sigma = Math.exp(denseVector.get(denseVector.size() - 1));
        double epsilon = (tuple3.f1 - getDotProduct(tuple3.f2, denseVector)) / sigma;
        double expEpsilon = Math.exp(epsilon);
        double curvature = expEpsilon / (sigma * sigma);
        if (tuple3.f2 instanceof SparseVector) {
            SparseVector sparse = (SparseVector) tuple3.f2;
            int[] indices = sparse.getIndices();
            double[] values = sparse.getValues();
            for (int r = 0; r < indices.length; r++) {
                double scaledRow = values[r] * curvature;
                // Fill both symmetric off-diagonal entries, then the diagonal.
                for (int c = 0; c < r; c++) {
                    double entry = scaledRow * values[c];
                    denseMatrix.add(indices[r], indices[c], entry);
                    denseMatrix.add(indices[c], indices[r], entry);
                }
                denseMatrix.add(indices[r], indices[r], scaledRow * values[r]);
            }
        } else {
            double[] features = ((DenseVector) tuple3.f2).getData();
            for (int r = 0; r < features.length; r++) {
                double scaledRow = features[r] * curvature;
                for (int c = 0; c < r; c++) {
                    double entry = scaledRow * features[c];
                    denseMatrix.add(r, c, entry);
                    denseMatrix.add(c, r, entry);
                }
                denseMatrix.add(r, r, scaledRow * features[r]);
            }
        }
        int last = denseVector.size() - 1;
        // Second derivative with respect to the last coefficient (log sigma).
        denseMatrix.add(last, last, epsilon * ((expEpsilon * (1.0d + epsilon)) - tuple3.f0));
    }

    /**
     * Reports that this objective provides second-derivative information.
     *
     * @return always {@code true}
     */
    @Override
    public boolean hasSecondDerivative() {
        return true;
    }

    /**
     * Computes the regularized mean loss over the given samples.
     *
     * @param iterable    samples
     * @param denseVector coefficient vector
     * @return tuple of (objective value, sample count); the objective is the mean loss (the plain
     *         sum when there are no samples) plus {@code l1 * |coef|_1} and
     *         {@code l2 * <coef, coef>} when the respective weight is non-zero
     */
    @Override
    public Tuple2<Double, Double> calcObjValue(Iterable<Tuple3<Double, Double, Vector>> iterable, DenseVector denseVector) {
        double count = 0.0d;
        double objective = 0.0d;
        for (Tuple3<Double, Double, Vector> sample : iterable) {
            objective += calcLoss(sample, denseVector);
            count += 1.0d;
        }
        if (ZERO != count) {
            objective /= count;
        }
        if (ZERO != this.l1) {
            objective += this.l1 * denseVector.normL1();
        }
        if (ZERO != this.l2) {
            objective += this.l2 * MatVecOp.dot(denseVector, denseVector);
        }
        return new Tuple2<>(objective, count);
    }

    /**
     * Computes the regularized mean gradient over the given samples.
     *
     * @param iterable     samples
     * @param denseVector  coefficient vector
     * @param denseVector2 output gradient; cleared first, then filled in place
     * @return number of samples visited
     */
    @Override
    public double calcGradient(Iterable<Tuple3<Double, Double, Vector>> iterable, DenseVector denseVector, DenseVector denseVector2) {
        double count = 0.0d;
        for (int k = 0; k < denseVector2.size(); k++) {
            denseVector2.set(k, ZERO);
        }
        for (Tuple3<Double, Double, Vector> sample : iterable) {
            count += 1.0d;
            updateGradient(sample, denseVector, denseVector2);
        }
        if (count > ZERO) {
            denseVector2.scaleEqual(1.0d / count);
        }
        if (ZERO != this.l2) {
            denseVector2.plusScaleEqual(denseVector, this.l2 * 2.0d);
        }
        if (ZERO != this.l1) {
            double[] coefs = denseVector.getData();
            for (int k = 0; k < denseVector.size(); k++) {
                denseVector2.add(k, Math.signum(coefs[k]) * this.l1);
            }
        }
        return count;
    }

    /**
     * Computes the Hessian, gradient and loss in a single pass over the samples.
     *
     * <p>Unlike {@link #calcGradient}, the accumulated sums are not divided by the sample count;
     * the regularization terms are multiplied by the count instead.
     *
     * @param iterable     samples
     * @param denseVector  coefficient vector
     * @param denseMatrix  output Hessian; cleared first, then filled in place
     * @param denseVector2 output gradient; cleared first, then filled in place
     * @return tuple of (sample count, summed loss) - note the order differs from
     *         {@link #calcObjValue}
     * @throws AkUnsupportedOperationException if {@link #hasSecondDerivative()} is {@code false}
     */
    @Override
    public Tuple2<Double, Double> calcHessianGradientLoss(Iterable<Tuple3<Double, Double, Vector>> iterable, DenseVector denseVector, DenseMatrix denseMatrix, DenseVector denseVector2) {
        if (!hasSecondDerivative()) {
            throw new AkUnsupportedOperationException("loss function can't support second derivative, newton precondition can not work.");
        }
        int size = denseVector2.size();
        for (int r = 0; r < size; r++) {
            denseVector2.set(r, ZERO);
            for (int c = 0; c < size; c++) {
                denseMatrix.set(r, c, ZERO);
            }
        }
        double count = 0.0d;
        double lossSum = 0.0d;
        for (Tuple3<Double, Double, Vector> sample : iterable) {
            updateHessian(sample, denseVector, denseMatrix);
            count += 1.0d;
            updateGradient(sample, denseVector, denseVector2);
            lossSum += calcLoss(sample, denseVector);
        }
        if (ZERO != this.l2) {
            double l2Scale = this.l2 * 2.0d * count;
            denseVector2.plusScaleEqual(denseVector, l2Scale);
            for (int k = 0; k < denseMatrix.numRows(); k++) {
                denseMatrix.add(k, k, l2Scale);
            }
        }
        if (ZERO != this.l1) {
            double l1Scale = this.l1 * count;
            double[] coefs = denseVector.getData();
            for (int k = 0; k < denseVector.size(); k++) {
                denseVector2.add(k, Math.signum(coefs[k]) * l1Scale);
            }
        }
        return Tuple2.of(count, lossSum);
    }

    /**
     * Evaluates the summed loss at {@code i + 1} points along a search direction:
     * {@code coef - k * d * direction} for {@code k = 0..i}.
     *
     * @param iterable     samples
     * @param denseVector  starting coefficient vector (not modified; a copy is used)
     * @param denseVector2 search direction
     * @param d            step length
     * @param i            number of steps
     * @return array of length {@code i + 1}; entry {@code k} is the loss sum after {@code k} steps
     */
    @Override
    public double[] calcSearchValues(Iterable<Tuple3<Double, Double, Vector>> iterable, DenseVector denseVector, DenseVector denseVector2, double d, int i) {
        double[] losses = new double[i + 1];
        DenseVector[] candidates = new DenseVector[i + 1];
        candidates[0] = denseVector.clone();
        DenseVector step = denseVector2.scale(d);
        for (int k = 1; k < i + 1; k++) {
            candidates[k] = candidates[k - 1].minus((Vector) step);
        }
        for (Tuple3<Double, Double, Vector> sample : iterable) {
            for (int k = 0; k < i + 1; k++) {
                losses[k] = losses[k] + calcLoss(sample, candidates[k]);
            }
        }
        return losses;
    }

    /**
     * Same as {@link #calcSearchValues} but any coordinate whose sign would flip relative to the
     * starting coefficient is clamped to zero before the loss is evaluated.
     *
     * @param list         samples
     * @param denseVector  starting coefficient vector
     * @param denseVector2 search direction
     * @param d            step length
     * @param i            number of steps
     * @return array of length {@code i + 1}; entry {@code k} is the loss sum after {@code k} steps
     */
    @Override
    public double[] constraintCalcSearchValues(List<Tuple3<Double, Double, Vector>> list, DenseVector denseVector, DenseVector denseVector2, double d, int i) {
        double[] losses = new double[i + 1];
        double[] start = denseVector.getData();
        double[] direction = denseVector2.getData();
        int dim = start.length;
        DenseVector candidate = new DenseVector(dim);
        double[] candidateData = candidate.getData();
        for (int k = 0; k < i + 1; k++) {
            double stepSize = d * k;
            for (int j = 0; j < dim; j++) {
                double moved = start[j] - (stepSize * direction[j]);
                // A negative product means the coordinate crossed zero: clamp it.
                if (moved * start[j] < ZERO) {
                    moved = 0.0d;
                }
                candidateData[j] = moved;
            }
            for (Tuple3<Double, Double, Vector> sample : list) {
                losses[k] = losses[k] + calcLoss(sample, candidate);
            }
        }
        return losses;
    }
}
