package com.alibaba.alink.operator.common.classification.ann;

import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import java.util.function.BiFunction;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/ann/SoftmaxLayerModelWithCrossEntropyLoss.class */
/**
 * Layer model that applies a row-wise softmax in {@code eval} and provides a cross-entropy style
 * loss in {@code loss}.
 *
 * <p>The layer has no parameters of its own: {@code resetModel} and {@code grad} do nothing, and
 * {@code computePrevDelta} always throws because this layer is expected to be the last one.
 */
public class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel implements AnnLossFunction {

    /**
     * Computes {@code -(1 / numRows) * sum(t * log(o))}, where {@code o} ranges over the entries of
     * {@code denseMatrix} and {@code t} over the matching entries of {@code denseMatrix2}, then
     * overwrites {@code denseMatrix3} with the element-wise difference {@code o - t}.
     *
     * <p>An element where both {@code t} and {@code o} are zero contributes {@code 0} to the sum
     * instead of {@code 0 * log(0) = NaN}.
     *
     * @param denseMatrix  values passed to {@code log}; presumably the probabilities written by
     *                     {@link #eval(DenseMatrix, DenseMatrix)} (confirm against the caller)
     * @param denseMatrix2 per-element weights of the log terms; presumably the target distribution
     * @param denseMatrix3 scratch/output matrix; on return it holds {@code denseMatrix - denseMatrix2}
     * @return the summed weighted log terms, negated and divided by the row count of {@code denseMatrix}
     */
    @Override
    public double loss(DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseMatrix denseMatrix3) {
        int numRows = denseMatrix.numRows();

        BiFunction<Double, Double, Double> weightedLog = (predicted, expected) -> {
            double probability = predicted.doubleValue();
            double weight = expected.doubleValue();
            // Math.exp in eval can underflow to exactly 0 for entries far below the row maximum.
            // With a zero weight that would give 0 * log(0) = 0 * -Infinity = NaN and poison the
            // whole sum, so treat that single case as a zero contribution.
            if (weight == 0.0d && probability == 0.0d) {
                return 0.0d;
            }
            return weight * Math.log(probability);
        };
        MatVecOp.apply(denseMatrix, denseMatrix2, denseMatrix3, weightedLog);

        // The sum must be taken before denseMatrix3 is reused for the difference below.
        double sum = (-(1.0d / numRows)) * denseMatrix3.sum();

        BiFunction<Double, Double, Double> difference =
            (predicted, expected) -> predicted.doubleValue() - expected.doubleValue();
        MatVecOp.apply(denseMatrix, denseMatrix2, denseMatrix3, difference);
        return sum;
    }

    /**
     * Does nothing; this layer reads no weights from the supplied vector.
     */
    @Override
    public void resetModel(DenseVector denseVector, int i) {
        // Intentionally empty.
    }

    /**
     * Writes the row-wise softmax of {@code denseMatrix} into {@code denseMatrix2}.
     *
     * <p>For each row, the row maximum is subtracted before exponentiating so that the largest
     * exponent is {@code exp(0) = 1}; each exponential is then divided by the row's sum.
     *
     * @param denseMatrix  input values; its row and column counts drive the iteration
     * @param denseMatrix2 output matrix, written at every (row, column) position of the input
     */
    @Override
    public void eval(DenseMatrix denseMatrix, DenseMatrix denseMatrix2) {
        int numRows = denseMatrix.numRows();
        int numCols = denseMatrix.numCols();
        for (int row = 0; row < numRows; row++) {
            // Largest entry of the row, starting from the most negative finite double.
            double rowMax = -Double.MAX_VALUE;
            for (int col = 0; col < numCols; col++) {
                double value = denseMatrix.get(row, col);
                if (value > rowMax) {
                    rowMax = value;
                }
            }

            // Shifted exponentials are stored un-normalised while their sum is accumulated.
            double expSum = 0.0d;
            for (int col = 0; col < numCols; col++) {
                double shiftedExp = Math.exp(denseMatrix.get(row, col) - rowMax);
                denseMatrix2.set(row, col, shiftedExp);
                expSum += shiftedExp;
            }

            // Normalise by the accumulated sum.
            for (int col = 0; col < numCols; col++) {
                denseMatrix2.set(row, col, denseMatrix2.get(row, col) / expSum);
            }
        }
    }

    /**
     * Always throws: this layer is expected to be the last layer.
     *
     * @throws AkUnclassifiedErrorException on every call
     */
    @Override
    public void computePrevDelta(DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseMatrix denseMatrix3) {
        throw new AkUnclassifiedErrorException("SoftmaxLayerModelWithCrossEntropyLoss should be the last layer.");
    }

    /**
     * Does nothing; this layer writes no gradient entries.
     */
    @Override
    public void grad(DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseVector denseVector, int i) {
        // Intentionally empty.
    }
}
