package com.alibaba.alink.operator.common.optim;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.recommendation.FmTrainParams;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/LocalFmOptimizer.class */
/**
 * Single-machine factorization machine (FM) trainer.
 *
 * <p>Each epoch makes one pass over the training records. For every record it first computes
 * the current prediction. It then updates the model sample by sample with AdaGrad-style steps:
 * <ul>
 *   <li>the model parts updated are the bias (when enabled), the pairwise factors and, when
 *       enabled, the linear weights;</li>
 *   <li>squared gradients are accumulated in {@code sigmaGii};</li>
 *   <li>each step is scaled by {@code f0 * learnRate / sqrt(accumulated + 1e-8)}.</li>
 * </ul>
 * After the pass, the training loss and evaluation metrics are computed from the predictions
 * made during that pass (i.e. before each record's own update). They are stored in the loss
 * curve, printed to stdout, and used for an early-stopping check.
 *
 * <p>Each training record is a {@code Tuple3}:
 * <ul>
 *   <li>{@code f0} scales the learning rate for that record (presumably a sample weight);</li>
 *   <li>{@code f1} is the label;</li>
 *   <li>{@code f2} is the feature vector (sparse or dense).</li>
 * </ul>
 *
 * <p>Model layout as read by this class: {@code factors[feature]} holds {@code dim[2]} factor
 * values, followed at index {@code dim[2]} by the linear weight, which is used only when
 * {@code dim[1] > 0}.
 *
 * <p>Instances keep mutable scratch buffers and per-epoch state, so they are not designed for
 * concurrent use.
 */
public class LocalFmOptimizer {
    /** Guard added to the accumulated squared gradient before taking its square root. */
    private static final double ADAGRAD_EPSILON = 1.0E-8d;
    /** Relative change of the mean loss below which training stops early. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-6d;

    private final List<Tuple3<Double, Double, Vector>> trainData;
    /** Regularization for [0] bias, [1] linear weights, [2] factors. */
    private final double[] lambda;
    /** Per-parameter sums of squared gradients, laid out like the model. */
    private BaseFmTrainBatchOp.FmDataFormat sigmaGii;
    private final double learnRate;
    private final BaseFmTrainBatchOp.LossFunction lossFunc;
    private final int numEpochs;
    private final BaseFmTrainBatchOp.Task task;
    /** Prediction made for each training record during the latest epoch. */
    private final double[] y;
    /** Scratch buffer: per-factor sum of {@code v[f] * x} over the current record. */
    private final double[] vx;
    /** Scratch buffer: per-factor sum of {@code (v[f] * x)^2} over the current record. */
    private final double[] v2x2;
    private long oldTime;
    /**
     * Three entries per epoch:
     * <ul>
     *   <li>{@code [3i]} mean loss;</li>
     *   <li>{@code [3i+1]} AUC (classification) or MAE (regression);</li>
     *   <li>{@code [3i+2]} accuracy (classification) or MSE (regression).</li>
     * </ul>
     * Epochs skipped by early stopping stay 0.0.
     */
    private final double[] lossCurve;
    protected BaseFmTrainBatchOp.FmDataFormat fmModel = null;
    private double oldLoss = 1.0d;
    /**
     * Latest epoch statistics:
     * <ul>
     *   <li>{@code [0]} summed loss;</li>
     *   <li>{@code [1]} record count;</li>
     *   <li>{@code [2]} AUC, or sum of absolute errors;</li>
     *   <li>{@code [3]} number of correct predictions, or sum of squared errors.</li>
     * </ul>
     */
    private final double[] loss = new double[4];
    /**
     * Model dimensions:
     * <ul>
     *   <li>{@code [0]} 1 if the intercept is used;</li>
     *   <li>{@code [1]} 1 if linear terms are used;</li>
     *   <li>{@code [2]} number of factors.</li>
     * </ul>
     */
    private final int[] dim = new int[3];

    /**
     * Creates an optimizer over an in-memory training set.
     *
     * @param list   training records ({@code f0}: learning-rate scale, {@code f1}: label,
     *               {@code f2}: features)
     * @param params reads NUM_EPOCHS, WITH_INTERCEPT, WITH_LINEAR_ITEM, NUM_FACTOR, LAMBDA_0/1/2,
     *               LEARN_RATE from {@code FmTrainParams} and TASK from {@code ModelParamName}
     */
    public LocalFmOptimizer(List<Tuple3<Double, Double, Vector>> list, Params params) {
        this.numEpochs = ((Integer) params.get(FmTrainParams.NUM_EPOCHS)).intValue();
        this.trainData = list;
        this.y = new double[list.size()];
        this.dim[0] = ((Boolean) params.get(FmTrainParams.WITH_INTERCEPT)).booleanValue() ? 1 : 0;
        this.dim[1] = ((Boolean) params.get(FmTrainParams.WITH_LINEAR_ITEM)).booleanValue() ? 1 : 0;
        this.dim[2] = ((Integer) params.get(FmTrainParams.NUM_FACTOR)).intValue();
        this.vx = new double[this.dim[2]];
        this.v2x2 = new double[this.dim[2]];
        this.lambda = new double[] {
            ((Double) params.get(FmTrainParams.LAMBDA_0)).doubleValue(),
            ((Double) params.get(FmTrainParams.LAMBDA_1)).doubleValue(),
            ((Double) params.get(FmTrainParams.LAMBDA_2)).doubleValue()
        };
        this.task = (BaseFmTrainBatchOp.Task) params.get(ModelParamName.TASK);
        this.learnRate = ((Double) params.get(FmTrainParams.LEARN_RATE)).doubleValue();
        this.oldTime = System.currentTimeMillis();
        if (this.task.equals(BaseFmTrainBatchOp.Task.REGRESSION)) {
            // Fixed target bounds of +/-1e20, widened by 20% of their range before being handed
            // to the loss (presumably as prediction clipping bounds — confirm in SquareLoss).
            double maxTarget = 1.0E20d;
            double minTarget = -1.0E20d;
            double margin = Math.max(maxTarget - minTarget, 1.0d) * 0.2d;
            this.lossFunc = new BaseFmTrainBatchOp.SquareLoss(maxTarget + margin, minTarget - margin);
        } else {
            this.lossFunc = new BaseFmTrainBatchOp.LogitLoss();
        }
        this.lossCurve = new double[this.numEpochs * 3];
    }

    /**
     * Sets the model to train in place and resets the AdaGrad accumulators.
     *
     * @param fmDataFormat initial model; it is mutated by {@link #optimize()}
     */
    public void setWithInitFactors(BaseFmTrainBatchOp.FmDataFormat fmDataFormat) {
        this.fmModel = fmDataFormat;
        this.sigmaGii = new BaseFmTrainBatchOp.FmDataFormat(this.fmModel.factors.length, this.dim, Criteria.INVALID_GAIN);
    }

    /**
     * Runs up to {@code numEpochs} epochs, stopping early once the mean loss stabilizes.
     *
     * @return the trained model (the same instance passed to {@link #setWithInitFactors}) and
     *         the loss curve
     * @throws IllegalStateException if {@link #setWithInitFactors} has not been called
     */
    public Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]> optimize() {
        if (this.fmModel == null) {
            throw new IllegalStateException(
                "Initial factors must be set via setWithInitFactors before optimize().");
        }
        for (int epoch = 0; epoch < this.numEpochs; epoch++) {
            updateFactors();
            calcLossAndEvaluation();
            if (termination(epoch)) {
                break;
            }
        }
        return Tuple2.of(this.fmModel, this.lossCurve);
    }

    /**
     * Records epoch {@code i}'s metrics in the loss curve, prints them, and decides whether to stop.
     *
     * @param i zero-based epoch index
     * @return true when the mean loss changed by less than {@code 1e-6} relative to the previous
     *         epoch (or stayed exactly 0)
     */
    public boolean termination(int i) {
        double sampleCount = this.loss[1];
        double meanLoss = this.loss[0] / sampleCount;
        int offset = 3 * i;
        this.lossCurve[offset] = meanLoss;
        this.lossCurve[offset + 2] = this.loss[3] / sampleCount;
        long elapsed = System.currentTimeMillis() - this.oldTime;
        // Branch on the same condition as calcLossAndEvaluation so loss[2]/loss[3] are
        // interpreted consistently with how they were computed.
        if (this.task.equals(BaseFmTrainBatchOp.Task.REGRESSION)) {
            this.lossCurve[offset + 1] = this.loss[2] / sampleCount;
            System.out.println("step : " + i + " loss : " + meanLoss + "  mae : " + (this.loss[2] / sampleCount) + " mse : " + (this.loss[3] / sampleCount) + " time : " + elapsed);
        } else {
            this.lossCurve[offset + 1] = this.loss[2];
            System.out.println("step : " + i + " loss : " + meanLoss + "  auc : " + this.loss[2] + " accuracy : " + (this.loss[3] / sampleCount) + " time : " + elapsed);
        }
        this.oldTime = System.currentTimeMillis();

        double previousLoss = this.oldLoss;
        this.oldLoss = meanLoss;
        if (previousLoss == 0.0d) {
            // A relative change is undefined here; treat a loss that stays at 0 as converged.
            return meanLoss == 0.0d;
        }
        return Math.abs(previousLoss - meanLoss) / Math.abs(previousLoss) < CONVERGENCE_TOLERANCE;
    }

    /**
     * Computes the summed training loss and task metrics from the predictions of the latest epoch.
     *
     * <p>Results are stored in {@code loss}:
     * <ul>
     *   <li>{@code [0]} summed loss;</li>
     *   <li>{@code [1]} record count;</li>
     *   <li>for regression: {@code [2]} sum of |error| and {@code [3]} sum of error^2;</li>
     *   <li>otherwise: {@code [2]} AUC and {@code [3]} number of correctly signed predictions.</li>
     * </ul>
     */
    public void calcLossAndEvaluation() {
        int count = this.y.length;
        double lossSum = 0.0d;
        for (int i = 0; i < count; i++) {
            lossSum += this.lossFunc.l(labelOf(i), this.y[i]);
        }
        if (this.task.equals(BaseFmTrainBatchOp.Task.REGRESSION)) {
            evaluateRegression();
        } else {
            evaluateBinaryClassification();
        }
        this.loss[0] = lossSum;
        this.loss[1] = count;
    }

    /** Stores the sums of absolute and squared prediction errors in loss[2] and loss[3]. */
    private void evaluateRegression() {
        double absErrorSum = 0.0d;
        double squaredErrorSum = 0.0d;
        for (int i = 0; i < this.y.length; i++) {
            double error = this.y[i] - labelOf(i);
            absErrorSum += Math.abs(error);
            squaredErrorSum += error * error;
        }
        this.loss[2] = absErrorSum;
        this.loss[3] = squaredErrorSum;
    }

    /**
     * Stores AUC in loss[2] and the number of correctly signed predictions in loss[3].
     *
     * <p>Labels above 0.5 count as positive. AUC uses the rank-sum (Mann–Whitney) formula with
     * mid-ranks for tied scores; it is 0 when only one class is present.
     */
    private void evaluateBinaryClassification() {
        int count = this.y.length;
        Integer[] order = new Integer[count];
        double correct = 0.0d;
        for (int i = 0; i < count; i++) {
            order[i] = Integer.valueOf(i);
            double label = labelOf(i);
            // NOTE(review): Criteria.INVALID_GAIN is used as the score decision threshold here
            // (it appears to be 0.0); a score exactly at the threshold is never counted correct.
            if (this.y[i] > Criteria.INVALID_GAIN && label > 0.5d) {
                correct += 1.0d;
            }
            if (this.y[i] < Criteria.INVALID_GAIN && label < 0.5d) {
                correct += 1.0d;
            }
        }
        Arrays.sort(order, Comparator.comparingDouble(idx -> this.y[idx.intValue()]));

        int positives = 0;
        int negatives = 0;
        double positiveRankSum = 0.0d;
        int groupStart = 0;
        while (groupStart < count) {
            double score = this.y[order[groupStart].intValue()];
            int groupEnd = groupStart + 1; // exclusive
            while (groupEnd < count && this.y[order[groupEnd].intValue()] == score) {
                groupEnd++;
            }
            // Tied records share the average of 1-based ranks groupStart+1 .. groupEnd.
            double midRank = 0.5d * ((groupStart + 1.0d) + groupEnd);
            for (int k = groupStart; k < groupEnd; k++) {
                if (labelOf(order[k].intValue()) > 0.5d) {
                    positives++;
                    positiveRankSum += midRank;
                } else {
                    negatives++;
                }
            }
            groupStart = groupEnd;
        }
        if (positives == 0 || negatives == 0) {
            this.loss[2] = 0.0d;
        } else {
            // Multiply in double: positives * negatives overflows int for large data sets.
            this.loss[2] = (positiveRankSum - 0.5d * positives * (positives + 1.0d))
                / ((double) positives * (double) negatives);
        }
        this.loss[3] = correct;
    }

    /** Returns the label ({@code f1}) of training record {@code sampleIndex}. */
    private double labelOf(int sampleIndex) {
        return this.trainData.get(sampleIndex).f1;
    }

    /**
     * Performs one epoch of per-record AdaGrad updates on the bias, factors and linear weights,
     * storing each record's pre-update prediction in {@code y}.
     */
    private void updateFactors() {
        final int numFactor = this.dim[2];
        for (int sample = 0; sample < this.trainData.size(); sample++) {
            Tuple3<Double, Double, Vector> record = this.trainData.get(sample);
            Vector features = record.f2;
            Tuple2<Double, double[]> prediction = calcY(features, this.fmModel, this.dim);
            double yHat = prediction.f0;
            // Per-factor sums v[f]*x over this record, computed before any update.
            double[] sumVx = prediction.f1;
            this.y[sample] = yHat;
            double dldy = this.lossFunc.dldy(record.f1, yHat);

            int[] indices = featureIndices(features);
            double[] values = featureValues(features);
            int nonZeros = featureCount(features, indices);
            double stepScale = record.f0 * this.learnRate;

            if (this.dim[0] > 0) {
                double grad = dldy + this.lambda[0] * this.fmModel.bias;
                this.sigmaGii.bias += grad * grad;
                this.fmModel.bias -= (stepScale * grad) / Math.sqrt(this.sigmaGii.bias + ADAGRAD_EPSILON);
            }
            for (int k = 0; k < nonZeros; k++) {
                int feature = indices == null ? k : indices[k];
                double x = values[k];
                double[] weights = this.fmModel.factors[feature];
                double[] accum = this.sigmaGii.factors[feature];
                for (int f = 0; f < numFactor; f++) {
                    // d(yHat)/d(v[feature][f]) = x * (sum_j v[j][f] x_j - v[feature][f] x)
                    double grad = dldy * x * (sumVx[f] - x * weights[f]) + this.lambda[2] * weights[f];
                    accum[f] += grad * grad;
                    weights[f] -= (stepScale * grad) / Math.sqrt(accum[f] + ADAGRAD_EPSILON);
                }
                if (this.dim[1] > 0) {
                    double grad = dldy * x + this.lambda[1] * weights[numFactor];
                    accum[numFactor] += grad * grad;
                    weights[numFactor] -= (grad * stepScale) / Math.sqrt(accum[numFactor] + ADAGRAD_EPSILON);
                }
            }
        }
    }

    /**
     * Computes the FM prediction for {@code vector}.
     *
     * @param vector       features (sparse or dense)
     * @param fmDataFormat model to evaluate
     * @param dims         model dimensions, same layout as {@code dim}
     * @return the prediction and the per-factor sums {@code v[f] * x}. The array is this
     *         instance's scratch buffer and is overwritten by the next call.
     */
    private Tuple2<Double, double[]> calcY(Vector vector, BaseFmTrainBatchOp.FmDataFormat fmDataFormat, int[] dims) {
        int[] indices = featureIndices(vector);
        double[] values = featureValues(vector);
        int nonZeros = featureCount(vector, indices);
        int numFactor = dims[2];

        Arrays.fill(this.vx, Criteria.INVALID_GAIN);
        Arrays.fill(this.v2x2, Criteria.INVALID_GAIN);
        double yHat = dims[0] > 0 ? Criteria.INVALID_GAIN + fmDataFormat.bias : 0.0d;
        for (int k = 0; k < nonZeros; k++) {
            int feature = indices == null ? k : indices[k];
            double x = values[k];
            double[] weights = fmDataFormat.factors[feature];
            if (dims[1] > 0) {
                yHat += x * weights[numFactor];
            }
            for (int f = 0; f < numFactor; f++) {
                double term = x * weights[f];
                this.vx[f] += term;
                this.v2x2[f] += term * term;
            }
        }
        // Pairwise interactions via 0.5 * ((sum v x)^2 - sum (v x)^2), per factor.
        for (int f = 0; f < numFactor; f++) {
            yHat += 0.5d * (this.vx[f] * this.vx[f] - this.v2x2[f]);
        }
        return Tuple2.of(Double.valueOf(yHat), this.vx);
    }

    /**
     * Returns the stored feature indices of a sparse vector, or {@code null} for a dense vector
     * (in which case a value's position is its feature index).
     */
    private static int[] featureIndices(Vector vector) {
        return vector instanceof SparseVector ? ((SparseVector) vector).getIndices() : null;
    }

    /** Returns the stored values of a sparse vector or the backing data of a dense vector. */
    private static double[] featureValues(Vector vector) {
        return vector instanceof SparseVector
            ? ((SparseVector) vector).getValues()
            : ((DenseVector) vector).getData();
    }

    /** Number of entries to visit: stored indices for sparse vectors, {@code size()} for dense. */
    private static int featureCount(Vector vector, int[] indices) {
        return indices != null ? indices.length : vector.size();
    }
}
