package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import java.util.Arrays;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/linear/SoftmaxObjFunc.class */
/**
 * Objective function for softmax (multinomial logistic) regression.
 *
 * <p>The coefficient vector is laid out as {@code k1} consecutive blocks of equal length
 * ({@code coefVector.size() / k1}), one block per class {@code 0 .. k1 - 1}. The class with index
 * {@code k1} has no block: its score is implicitly {@code 0}, which is why every normalizer in
 * this class starts from {@code exp(0) = 1}.
 *
 * <p>Each sample is a {@code Tuple3}: {@code f0} is used as the sample weight, {@code f1} is the
 * class label (truncated to {@code int}) and {@code f2} is the feature vector, handled either as a
 * {@link DenseVector} or as a {@link SparseVector}.
 */
public class SoftmaxObjFunc extends OptimObjFunc {
    private static final long serialVersionUID = -5686349186208428792L;

    /** Number of classes minus one, i.e. the number of coefficient blocks. */
    private final int k1;

    /**
     * Creates the objective function.
     *
     * @param params parameters; must contain {@link ModelParamName#NUM_CLASSES} with a value of at least 2
     * @throws IllegalArgumentException if the number of classes is smaller than 2
     */
    public SoftmaxObjFunc(Params params) {
        super(params);
        final int numClasses = this.params.get(ModelParamName.NUM_CLASSES);
        if (numClasses < 2) {
            // k1 == 0 would otherwise surface later as a divide-by-zero in every method.
            throw new IllegalArgumentException(
                "Softmax requires at least 2 classes, but got " + numClasses);
        }
        this.k1 = numClasses - 1;
    }

    /**
     * Computes the negative log-likelihood of one sample:
     * {@code log(1 + sum_k exp(eta_k)) - eta_label}, where {@code eta_label} is taken as 0 when the
     * label does not address one of the {@code k1} coefficient blocks.
     *
     * <p>The result is NOT multiplied by the sample weight {@code f0}.
     * NOTE(review): presumably the caller applies the weight; confirm against OptimObjFunc.
     *
     * @param labelVector sample (weight, label, features)
     * @param coefVector  current coefficients
     * @return unweighted loss of the sample
     */
    @Override // com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc
    public double calcLoss(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector) {
        final int featDim = coefVector.size() / this.k1;
        final double[] etas = calcEtas(labelVector.f2, coefVector.getData(), featDim);
        final int label = labelVector.f1.intValue();

        // Log-sum-exp with a max shift so that large scores do not overflow Math.exp.
        final double shift = maxWithZero(etas);
        double sumExp = Math.exp(-shift); // contribution of the implicit zero-score class
        for (double eta : etas) {
            sumExp += Math.exp(eta - shift);
        }
        final double labelEta = (label >= 0 && label < this.k1) ? etas[label] : 0.0d;
        return shift + Math.log(sumExp) - labelEta;
    }

    /**
     * Computes the weighted loss summed over all samples at {@code numStep + 1} points along a
     * search direction. Entry {@code s} uses the scores {@code eta_k - s * beta * (x . dir_k)}.
     *
     * <p>The normalizer for consecutive steps is obtained by repeatedly dividing
     * {@code exp(eta_k)} by {@code exp(beta * (x . dir_k))}, so only two exponentials per class
     * are evaluated for each sample.
     *
     * @param labelVectors samples (weight, label, features)
     * @param coefVector   current coefficients
     * @param dirVec       search direction, same layout as {@code coefVector}
     * @param beta         step length
     * @param numStep      number of steps; the result has {@code numStep + 1} entries
     * @return weighted losses, one per step (index 0 is the current point)
     */
    @Override // com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc
    public double[] calcSearchValues(Iterable<Tuple3<Double, Double, Vector>> labelVectors, DenseVector coefVector, DenseVector dirVec, double beta, int numStep) {
        final int numValues = numStep + 1;
        final double[] losses = new double[numValues];
        final double[] sumExps = new double[numValues];
        // One extra slot: index k1 is never written and stays 0, the score of the implicit class.
        final double[] etas = new double[this.k1 + 1];
        final double[] etaSteps = new double[this.k1 + 1];

        for (Tuple3<Double, Double, Vector> labelVector : labelVectors) {
            calcEta(labelVector, coefVector, dirVec, beta, etas, etaSteps);
            final double weight = labelVector.f0;
            final int label = labelVector.f1.intValue();
            final double labelEta = etas[label];
            final double labelEtaStep = etaSteps[label];

            // The label term carries the sample weight just like the log-normalizer below,
            // which keeps these values consistent with the weighted gradient.
            for (int s = 0; s < numValues; s++) {
                losses[s] -= weight * (labelEta - (s * labelEtaStep));
            }

            Arrays.fill(sumExps, 1.0d);
            for (int k = 0; k < this.k1; k++) {
                double term = Math.exp(etas[k]);
                final double ratio = Math.exp(etaSteps[k]);
                for (int s = 0; s < numValues; s++) {
                    sumExps[s] += term;
                    term /= ratio;
                }
            }

            for (int s = 0; s < numValues; s++) {
                losses[s] += weight * Math.log(sumExps[s]);
            }
        }
        return losses;
    }

    /**
     * Fills, for every class block {@code k < k1}, the score {@code x . coef_k} into
     * {@code etas[k]} and the per-step score change {@code beta * (x . dir_k)} into
     * {@code etaSteps[k]}. Entries at index {@code k1} and beyond are left untouched.
     */
    private void calcEta(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector, DenseVector dirVec, double beta, double[] etas, double[] etaSteps) {
        final int featDim = coefVector.size() / this.k1;
        final double[] coefs = coefVector.getData();
        final double[] dirs = dirVec.getData();

        if (labelVector.f2 instanceof DenseVector) {
            final double[] x = ((DenseVector) labelVector.f2).getData();
            for (int k = 0; k < this.k1; k++) {
                final int offset = k * featDim;
                double eta = 0.0d;
                double etaStep = 0.0d;
                for (int j = 0; j < featDim; j++) {
                    eta += x[j] * coefs[offset + j];
                    etaStep += x[j] * dirs[offset + j] * beta;
                }
                etas[k] = eta;
                etaSteps[k] = etaStep;
            }
            return;
        }

        final SparseVector sparse = (SparseVector) labelVector.f2;
        final int[] indices = sparse.getIndices();
        final double[] values = sparse.getValues();
        for (int k = 0; k < this.k1; k++) {
            final int offset = k * featDim;
            double eta = 0.0d;
            double etaStep = 0.0d;
            for (int j = 0; j < indices.length; j++) {
                final int pos = indices[j] + offset;
                eta += values[j] * coefs[pos];
                etaStep += values[j] * dirs[pos] * beta;
            }
            etas[k] = eta;
            etaSteps[k] = etaStep;
        }
    }

    /**
     * Adds the weighted gradient of one sample to {@code updateGrad}:
     * block {@code k} receives {@code weight * (phi_k - [label == k]) * x}.
     *
     * @param labelVector sample (weight, label, features)
     * @param coefVector  current coefficients
     * @param updateGrad  gradient accumulator, modified in place
     */
    @Override // com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc
    public void updateGradient(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector, DenseVector updateGrad) {
        final double[] phi = calcPhi(labelVector, coefVector);
        final int label = labelVector.f1.intValue();
        if (label < this.k1) {
            // Labels >= k1 address the implicit class, which has no coefficient block.
            phi[label] -= 1.0d;
        }
        final double weight = labelVector.f0;
        final int featDim = coefVector.size() / this.k1;

        if (labelVector.f2 instanceof DenseVector) {
            final double[] x = ((DenseVector) labelVector.f2).getData();
            for (int k = 0; k < this.k1; k++) {
                final double scale = phi[k] * weight;
                final int offset = k * featDim;
                for (int j = 0; j < featDim; j++) {
                    updateGrad.add(offset + j, x[j] * scale);
                }
            }
            return;
        }

        final SparseVector sparse = (SparseVector) labelVector.f2;
        final int[] indices = sparse.getIndices();
        final double[] values = sparse.getValues();
        for (int k = 0; k < this.k1; k++) {
            final double scale = phi[k] * weight;
            final int offset = k * featDim;
            for (int j = 0; j < indices.length; j++) {
                updateGrad.add(offset + indices[j], values[j] * scale);
            }
        }
    }

    /**
     * Computes the softmax probabilities of the {@code k1} explicit classes:
     * {@code phi_k = exp(eta_k) / (1 + sum_j exp(eta_j))}.
     *
     * @return a new array of length {@code k1}
     */
    private double[] calcPhi(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector) {
        final int featDim = coefVector.size() / this.k1;
        final double[] phi = calcEtas(labelVector.f2, coefVector.getData(), featDim);

        // Shift by the largest score so that Math.exp cannot overflow into Infinity / NaN.
        final double shift = maxWithZero(phi);
        double sumExp = Math.exp(-shift); // contribution of the implicit zero-score class
        for (int k = 0; k < this.k1; k++) {
            phi[k] = Math.exp(phi[k] - shift);
            sumExp += phi[k];
        }
        for (int k = 0; k < this.k1; k++) {
            phi[k] /= sumExp;
        }
        return phi;
    }

    /**
     * Computes the score {@code x . coef_k} for every class block {@code k < k1}.
     *
     * @param features feature vector, dense or sparse
     * @param coefs    raw coefficient data, {@code k1} blocks of length {@code featDim}
     * @param featDim  length of one coefficient block
     * @return a new array of length {@code k1} holding the scores
     */
    private double[] calcEtas(Vector features, double[] coefs, int featDim) {
        final double[] etas = new double[this.k1];
        if (features instanceof DenseVector) {
            final double[] x = ((DenseVector) features).getData();
            for (int k = 0; k < this.k1; k++) {
                final int offset = k * featDim;
                double eta = 0.0d;
                for (int j = 0; j < featDim; j++) {
                    eta += x[j] * coefs[offset + j];
                }
                etas[k] = eta;
            }
        } else {
            final SparseVector sparse = (SparseVector) features;
            final int[] indices = sparse.getIndices();
            final double[] values = sparse.getValues();
            for (int k = 0; k < this.k1; k++) {
                final int offset = k * featDim;
                double eta = 0.0d;
                for (int j = 0; j < indices.length; j++) {
                    eta += values[j] * coefs[offset + indices[j]];
                }
                etas[k] = eta;
            }
        }
        return etas;
    }

    /** Returns the largest of the given scores and 0 (the score of the implicit class). */
    private static double maxWithZero(double[] etas) {
        double max = 0.0d;
        for (double eta : etas) {
            if (eta > max) {
                max = eta;
            }
        }
        return max;
    }

    /**
     * Adds the Hessian contribution of one sample to {@code hessian}. The diagonal block of class
     * {@code k} receives {@code (phi_k - phi_k^2) * x x^T}; the two off-diagonal blocks of a class
     * pair {@code (a, b)} each receive {@code -phi_a * phi_b * x x^T}.
     *
     * <p>NOTE(review): unlike {@link #updateGradient}, the contribution is not scaled by the
     * sample weight {@code f0}; confirm whether the caller expects an unweighted Hessian.
     *
     * @param labelVector sample (weight, label, features)
     * @param coefVector  current coefficients
     * @param hessian     Hessian accumulator, modified in place
     */
    @Override // com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc
    public void updateHessian(Tuple3<Double, Double, Vector> labelVector, DenseVector coefVector, DenseMatrix hessian) {
        final double[] phi = calcPhi(labelVector, coefVector);
        final int featDim = coefVector.size() / this.k1;
        if (labelVector.f2 instanceof DenseVector) {
            addDenseHessian(((DenseVector) labelVector.f2).getData(), phi, featDim, hessian);
        } else {
            final SparseVector sparse = (SparseVector) labelVector.f2;
            addSparseHessian(sparse.getIndices(), sparse.getValues(), phi, featDim, hessian);
        }
    }

    /** Adds the Hessian contribution of a dense feature vector {@code x}. */
    private void addDenseHessian(double[] x, double[] phi, int featDim, DenseMatrix hessian) {
        // Diagonal blocks.
        for (int k = 0; k < this.k1; k++) {
            final double scale = phi[k] - (phi[k] * phi[k]);
            final int offset = k * featDim;
            for (int r = 0; r < featDim; r++) {
                for (int c = 0; c < featDim; c++) {
                    hessian.add(r + offset, c + offset, x[r] * x[c] * scale);
                }
            }
        }
        // Off-diagonal blocks: each class pair is visited once and mirrored.
        for (int a = 0; a < this.k1; a++) {
            for (int b = a + 1; b < this.k1; b++) {
                final double scale = (-phi[a]) * phi[b];
                final int offsetA = a * featDim;
                final int offsetB = b * featDim;
                for (int r = 0; r < featDim; r++) {
                    for (int c = 0; c < featDim; c++) {
                        final double value = x[r] * x[c] * scale;
                        hessian.add(r + offsetA, c + offsetB, value);
                        hessian.add(c + offsetB, r + offsetA, value);
                    }
                }
            }
        }
    }

    /** Adds the Hessian contribution of a sparse feature vector given by its indices and values. */
    private void addSparseHessian(int[] indices, double[] values, double[] phi, int featDim, DenseMatrix hessian) {
        // Diagonal blocks.
        for (int k = 0; k < this.k1; k++) {
            final double scale = phi[k] - (phi[k] * phi[k]);
            final int offset = k * featDim;
            for (int r = 0; r < indices.length; r++) {
                final int row = indices[r] + offset;
                final double rowScale = values[r] * scale;
                for (int c = 0; c < indices.length; c++) {
                    hessian.add(row, indices[c] + offset, values[c] * rowScale);
                }
            }
        }
        // Off-diagonal blocks: each class pair is visited once and mirrored.
        for (int a = 0; a < this.k1; a++) {
            for (int b = a + 1; b < this.k1; b++) {
                final double scale = (-phi[a]) * phi[b];
                final int offsetA = a * featDim;
                final int offsetB = b * featDim;
                for (int r = 0; r < indices.length; r++) {
                    for (int c = 0; c < indices.length; c++) {
                        final double value = values[r] * values[c] * scale;
                        hessian.add(indices[r] + offsetA, indices[c] + offsetB, value);
                        hessian.add(indices[c] + offsetB, indices[r] + offsetA, value);
                    }
                }
            }
        }
    }

    /**
     * Reports that {@link #updateHessian} is implemented.
     *
     * @return always {@code true}
     */
    @Override // com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc
    public boolean hasSecondDerivative() {
        return true;
    }
}
