package com.alibaba.alink.operator.common.optim.objfunc;

import com.alibaba.alink.common.exceptions.AkUnimplementedOperationException;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.linear.AftRegObjFunc;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.linear.SoftmaxObjFunc;
import com.alibaba.alink.operator.common.linear.UnaryLossObjFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.LogLossFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.PerceptronLossFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.SmoothHingeLossFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.SquareLossFunc;
import com.alibaba.alink.operator.common.linear.unarylossfunc.SvrLossFunc;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.regression.LinearSvrTrainParams;
import com.alibaba.alink.params.shared.linear.HasL1;
import com.alibaba.alink.params.shared.linear.HasL2;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/objfunc/OptimObjFunc.class */
/**
 * Base class of the objective functions minimized by the linear-model optimizers.
 *
 * <p>Every sample is a {@code Tuple3<Double, Double, Vector>} whose {@code f0} is used as the sample
 * weight and whose {@code f2} is the feature vector ({@code f1} is presumably the label; it is only
 * consumed by subclasses). Subclasses provide the per-sample loss, gradient and (optionally) Hessian
 * contributions; this class aggregates them over a batch and adds the L1 / L2 regularization terms
 * read from {@link HasL1#L_1} and {@link HasL2#L_2}. A coefficient of {@code 0.0} disables the
 * corresponding regularization term.
 */
public abstract class OptimObjFunc implements Serializable {
    private static final long serialVersionUID = -4624127324724005715L;

    /** L1 regularization coefficient; {@code 0.0} disables the L1 term. */
    protected final double l1;
    /** L2 regularization coefficient; {@code 0.0} disables the L2 term. */
    protected final double l2;
    /** Training parameters; never {@code null} after construction. */
    protected Params params;

    /**
     * Creates the objective function and reads the regularization coefficients.
     *
     * @param params training parameters; {@code null} is replaced by an empty {@link Params}.
     */
    public OptimObjFunc(Params params) {
        this.params = (params == null) ? new Params() : params;
        this.l1 = (Double) this.params.get(HasL1.L_1);
        this.l2 = (Double) this.params.get(HasL2.L_2);
    }

    /** @return the L1 regularization coefficient. */
    public double getL1() {
        return this.l1;
    }

    /** @return the L2 regularization coefficient. */
    public double getL2() {
        return this.l2;
    }

    /**
     * Computes the loss of a single sample at the given coefficients. The sample weight is NOT expected
     * to be applied here: every aggregating method of this class multiplies the result by {@code f0}.
     *
     * @param labelVector one sample (weight, label, features).
     * @param coefVector  model coefficients.
     * @return the unweighted loss of the sample.
     */
    public abstract double calcLoss(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector);

    /**
     * Adds the gradient contribution of one sample to {@code grad}. The callers in this class zero the
     * accumulator before the first call.
     *
     * @param labelVector one sample (weight, label, features).
     * @param coefVector  model coefficients.
     * @param grad        gradient accumulator, updated in place.
     */
    public abstract void updateGradient(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector,
                                        DenseVector grad);

    /**
     * Adds the Hessian contribution of one sample to {@code hessian}. The callers in this class zero the
     * accumulator before the first call.
     *
     * @param labelVector one sample (weight, label, features).
     * @param coefVector  model coefficients.
     * @param hessian     Hessian accumulator, updated in place.
     */
    public abstract void updateHessian(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector,
                                       DenseMatrix hessian);

    /**
     * @return whether {@link #updateHessian} is supported; {@link #calcHessianGradientLoss} throws when
     *         this is {@code false}.
     */
    public abstract boolean hasSecondDerivative();

    /**
     * Computes the regularized objective: the weighted mean loss over the batch plus the L1 and L2 terms.
     *
     * @param labelVectors samples (weight, label, features).
     * @param coefVector   model coefficients.
     * @return {@code (objectiveValue, weightSum)}.
     */
    public Tuple2<Double, Double> calcObjValue(Iterable<Tuple3<Double, Double, Vector>> labelVectors,
                                               DenseVector coefVector) {
        double weightSum = 0.0;
        double weightedLossSum = 0.0;
        for (Tuple3<Double, Double, Vector> labelVector : labelVectors) {
            double weight = labelVector.f0;
            weightedLossSum += calcLoss(labelVector, coefVector) * weight;
            weightSum += weight;
        }
        return finalizeObjValue(coefVector, weightedLossSum, weightSum);
    }

    /**
     * Turns an accumulated weighted loss sum into the regularized objective value.
     *
     * @param coefVector model coefficients (used for the regularization terms).
     * @param lossSum    weighted loss sum; divided by {@code weightSum} unless that is zero.
     * @param weightSum  total sample weight.
     * @return {@code (objectiveValue, weightSum)}.
     */
    public Tuple2<Double, Double> finalizeObjValue(DenseVector coefVector, double lossSum, double weightSum) {
        double objValue = lossSum;
        if (weightSum != 0.0) {
            objValue /= weightSum;
        }
        if (this.l1 != 0.0) {
            objValue += this.l1 * coefVector.normL1();
        }
        if (this.l2 != 0.0) {
            objValue += this.l2 * MatVecOp.dot(coefVector, coefVector);
        }
        return Tuple2.of(objValue, weightSum);
    }

    /**
     * Computes the regularized gradient of the objective (see {@link #calcObjValue}) into {@code grad}.
     * Sparse feature vectors are resized in place to the coefficient dimension first.
     *
     * @param labelVectors samples (weight, label, features).
     * @param coefVector   model coefficients.
     * @param grad         output gradient; overwritten.
     * @return the total sample weight.
     */
    public double calcGradient(Iterable<Tuple3<Double, Double, Vector>> labelVectors, DenseVector coefVector,
                               DenseVector grad) {
        Arrays.fill(grad.getData(), 0.0);
        double weightSum = 0.0;
        for (Tuple3<Double, Double, Vector> labelVector : labelVectors) {
            if (labelVector.f2 instanceof SparseVector) {
                ((SparseVector) labelVector.f2).setSize(coefVector.size());
            }
            weightSum += labelVector.f0;
            updateGradient(labelVector, coefVector, grad);
        }
        finalizeGradient(coefVector, grad, weightSum);
        return weightSum;
    }

    /**
     * Normalizes an accumulated gradient by the total weight and adds the regularization gradients
     * ({@code 2 * l2 * coef} and the L1 sub-gradient {@code l1 * sign(coef)}).
     *
     * @param coefVector model coefficients.
     * @param grad       accumulated gradient, updated in place.
     * @param weightSum  total sample weight; normalization is skipped unless it is positive.
     */
    public void finalizeGradient(DenseVector coefVector, DenseVector grad, double weightSum) {
        if (weightSum > 0.0) {
            grad.scaleEqual(1.0 / weightSum);
        }
        if (this.l2 != 0.0) {
            grad.plusScaleEqual(coefVector, this.l2 * 2.0);
        }
        if (this.l1 != 0.0) {
            double[] coef = coefVector.getData();
            for (int i = 0; i < coefVector.size(); i++) {
                grad.add(i, Math.signum(coef[i]) * this.l1);
            }
        }
    }

    /**
     * Accumulates the Hessian, gradient and weighted loss over a batch. The results are NOT divided by
     * the total weight; the regularization terms are therefore added scaled by the weight sum (see
     * {@link #finalizeHessianGradientLoss}).
     *
     * @param labelVectors samples (weight, label, features).
     * @param coefVector   model coefficients.
     * @param hessian      output Hessian; overwritten.
     * @param grad         output gradient; overwritten.
     * @return {@code (weightSum, weightedLossSum)}; the loss sum excludes regularization.
     * @throws AkUnsupportedOperationException if {@link #hasSecondDerivative()} is {@code false}.
     */
    public Tuple2<Double, Double> calcHessianGradientLoss(Iterable<Tuple3<Double, Double, Vector>> labelVectors,
                                                          DenseVector coefVector, DenseMatrix hessian,
                                                          DenseVector grad) {
        if (!hasSecondDerivative()) {
            throw new AkUnsupportedOperationException(
                "loss function can't support second derivative, newton precondition can not work.");
        }
        Arrays.fill(grad.getData(), 0.0);
        Arrays.fill(hessian.getData(), 0.0);
        double weightSum = 0.0;
        double lossSum = 0.0;
        for (Tuple3<Double, Double, Vector> labelVector : labelVectors) {
            double weight = labelVector.f0;
            updateHessian(labelVector, coefVector, hessian);
            weightSum += weight;
            updateGradient(labelVector, coefVector, grad);
            // calcLoss is unweighted; apply the sample weight as calcObjValue/calcSearchValues do, so that
            // lossSum / weightSum is the weighted mean loss matching the weighted gradient and Hessian.
            lossSum += calcLoss(labelVector, coefVector) * weight;
        }
        finalizeHessianGradientLoss(coefVector, hessian, grad, weightSum);
        return Tuple2.of(weightSum, lossSum);
    }

    /**
     * Adds the regularization contributions to an UNNORMALIZED (weight-summed) gradient and Hessian, so
     * every term is multiplied by {@code weightSum} to keep the same scale as the data terms.
     *
     * @param coefVector model coefficients.
     * @param hessian    accumulated Hessian, updated in place (L2 adds {@code 2 * l2 * weightSum} to the
     *                   diagonal).
     * @param grad       accumulated gradient, updated in place.
     * @param weightSum  total sample weight.
     */
    public void finalizeHessianGradientLoss(DenseVector coefVector, DenseMatrix hessian, DenseVector grad,
                                            double weightSum) {
        if (this.l1 != 0.0) {
            double scaledL1 = this.l1 * weightSum;
            double[] coef = coefVector.getData();
            for (int i = 0; i < coefVector.size(); i++) {
                grad.add(i, Math.signum(coef[i]) * scaledL1);
            }
        }
        if (this.l2 != 0.0) {
            double scaledL2 = this.l2 * 2.0 * weightSum;
            grad.plusScaleEqual(coefVector, scaledL2);
            for (int i = 0; i < hessian.numRows(); i++) {
                hessian.add(i, i, scaledL2);
            }
        }
    }

    /**
     * Evaluates the weighted loss sums (unnormalized, without regularization) along a search direction,
     * at the points {@code coef - k * beta * dir} for {@code k = 0..numStep}.
     *
     * @param labelVectors samples (weight, label, features).
     * @param coefVector   starting coefficients.
     * @param dirVec       search direction (subtracted from the coefficients).
     * @param beta         step length between consecutive trial points.
     * @param numStep      number of steps; {@code numStep + 1} values are returned.
     * @return weighted loss sum at each trial point, index {@code k} for step {@code k}.
     */
    public double[] calcSearchValues(Iterable<Tuple3<Double, Double, Vector>> labelVectors,
                                     DenseVector coefVector, DenseVector dirVec, double beta, int numStep) {
        double[] losses = new double[numStep + 1];
        DenseVector[] trialCoefs = new DenseVector[numStep + 1];
        // NOTE(review): mo136clone appears to be the decompiler's name for DenseVector's clone(); confirm.
        trialCoefs[0] = coefVector.mo136clone();
        DenseVector step = dirVec.scale(beta);
        for (int k = 1; k <= numStep; k++) {
            trialCoefs[k] = trialCoefs[k - 1].minus((Vector) step);
        }
        for (Tuple3<Double, Double, Vector> labelVector : labelVectors) {
            double weight = labelVector.f0;
            for (int k = 0; k <= numStep; k++) {
                losses[k] += calcLoss(labelVector, trialCoefs[k]) * weight;
            }
        }
        return losses;
    }

    /**
     * Like {@link #calcSearchValues}, but each trial point is projected: any coordinate whose sign would
     * flip relative to the starting coefficient is set to zero.
     *
     * @param labelVectors samples (weight, label, features).
     * @param coefVector   starting coefficients.
     * @param dirVec       search direction (subtracted from the coefficients).
     * @param beta         step length between consecutive trial points.
     * @param numStep      number of steps; {@code numStep + 1} values are returned.
     * @return weighted loss sum (unnormalized, without regularization) at each projected trial point.
     */
    public double[] constraintCalcSearchValues(List<Tuple3<Double, Double, Vector>> labelVectors,
                                               DenseVector coefVector, DenseVector dirVec, double beta,
                                               int numStep) {
        double[] losses = new double[numStep + 1];
        double[] coef = coefVector.getData();
        double[] dir = dirVec.getData();
        int dim = coef.length;
        // Single buffer reused for every trial point.
        DenseVector trialCoefVector = new DenseVector(dim);
        double[] trialCoef = trialCoefVector.getData();
        for (int k = 0; k <= numStep; k++) {
            double stepSize = beta * k;
            for (int j = 0; j < dim; j++) {
                double value = coef[j] - (stepSize * dir[j]);
                trialCoef[j] = (value * coef[j] < 0.0) ? 0.0 : value;
            }
            for (Tuple3<Double, Double, Vector> labelVector : labelVectors) {
                losses[k] += calcLoss(labelVector, trialCoefVector) * labelVector.f0;
            }
        }
        return losses;
    }

    /**
     * Creates the objective function for a linear model type.
     *
     * @param linearModelType model type.
     * @param params          training parameters ({@code SVR} reads {@link LinearSvrTrainParams#TAU}).
     * @return a new objective function.
     * @throws AkUnimplementedOperationException for model types without an objective function.
     */
    public static OptimObjFunc getObjFunction(LinearModelType linearModelType, Params params) {
        switch (linearModelType) {
            case LinearReg:
                return new UnaryLossObjFunc(new SquareLossFunc(), params);
            case SVR:
                double tau = (Double) params.get(LinearSvrTrainParams.TAU);
                return new UnaryLossObjFunc(new SvrLossFunc(tau), params);
            case LR:
                return new UnaryLossObjFunc(new LogLossFunc(), params);
            case SVM:
                return new UnaryLossObjFunc(new SmoothHingeLossFunc(), params);
            case Perceptron:
                return new UnaryLossObjFunc(new PerceptronLossFunc(), params);
            case AFT:
                return new AftRegObjFunc(params);
            case Softmax:
                return new SoftmaxObjFunc(params);
            default:
                throw new AkUnimplementedOperationException("Linear model type is Not implemented yet!");
        }
    }
}
