package com.mayabot.nlp.fasttext.loss;

import com.mayabot.nlp.fasttext.Model;
import com.mayabot.nlp.fasttext.blas.DenseVector;
import com.mayabot.nlp.fasttext.blas.Matrix;
import com.mayabot.nlp.fasttext.utils.IOUtilsKt;
import com.mayabot.nlp.fasttext.utils.IntArrayList;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: SoftmaxLoss.kt */
@Metadata(mv = {1, 1, 16}, bv = {1, IOUtilsKt.byteZero, 3}, k = 1, d1 = {"��8\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u000b\n��\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u0010\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\bH\u0016J0\u0010\t\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\f2\u0006\u0010\r\u001a\u00020\u000e2\u0006\u0010\u0007\u001a\u00020\b2\u0006\u0010\u000f\u001a\u00020\n2\u0006\u0010\u0010\u001a\u00020\u0011H\u0016¨\u0006\u0012"}, d2 = {"Lcom/mayabot/nlp/fasttext/loss/SoftmaxLoss;", "Lcom/mayabot/nlp/fasttext/loss/Loss;", "wo", "Lcom/mayabot/nlp/fasttext/blas/Matrix;", "(Lcom/mayabot/nlp/fasttext/blas/Matrix;)V", "computeOutput", "", "state", "Lcom/mayabot/nlp/fasttext/Model$State;", "forward", "", "targets", "Lcom/mayabot/nlp/fasttext/utils/IntArrayList;", "targetIndex", "", "lr", "backprop", "", "fastText4j"})
/* loaded from: input_file:com/mayabot/nlp/fasttext/loss/SoftmaxLoss.class */
/**
 * Full softmax loss computed over every row of the output matrix {@code wo}.
 *
 * <p>{@link #computeOutput} fills the state's output vector with {@code mul(wo, hidden)} and then
 * normalizes it with a max-shifted softmax. {@link #forward} returns the negated
 * {@code Loss.Companion.log} of the target's normalized score. When requested, it also applies a
 * gradient step in which the row whose index equals the target gets label 1 and every other row
 * gets label 0.
 */
public final class SoftmaxLoss extends Loss {

    /**
     * Creates a softmax loss backed by the given output matrix.
     *
     * @param wo output matrix passed to the {@link Loss} base class; must not be null
     */
    public SoftmaxLoss(@NotNull Matrix wo) {
        super(wo);
        Intrinsics.checkParameterIsNotNull(wo, "wo");
    }

    /**
     * Writes softmax-normalized scores into {@code state.getOutput()}.
     *
     * <p>The output vector is first overwritten by {@code mul(getWo(), state.getHidden())}
     * (presumably the matrix-vector product of {@code wo} and the hidden vector). Each entry is
     * then replaced by {@code exp(x - max) / sum}, so the entries are non-negative and sum to 1
     * (up to float rounding).
     *
     * @param state model state whose hidden vector is read and whose output vector is overwritten
     */
    @Override
    public void computeOutput(@NotNull Model.State state) {
        Intrinsics.checkParameterIsNotNull(state, "state");
        final DenseVector scores = state.getOutput();
        scores.mul(getWo(), state.getHidden());

        final int size = scores.length();

        // Subtract the largest score before exponentiating so Math.exp cannot overflow.
        float maxScore = scores.get(0);
        for (int k = 0; k < size; k++) {
            maxScore = Math.max(scores.get(k), maxScore);
        }

        // Exponentiate in place, accumulating the normalizer in index order.
        float total = 0.0f;
        for (int k = 0; k < size; k++) {
            final float shifted = (float) Math.exp(scores.get(k) - maxScore);
            scores.set(k, shifted);
            total += shifted;
        }

        for (int k = 0; k < size; k++) {
            scores.set(k, scores.get(k) / total);
        }
    }

    /**
     * Computes the softmax loss for one target and optionally back-propagates it.
     *
     * <p>When {@code backprop} is true, every row {@code r} of {@code wo} uses
     * {@code alpha = lr * (label - p[r])}, where {@code label} is 1 for the target row and 0
     * otherwise. That row, scaled by {@code alpha}, is added into {@code state.getGrad()} (via
     * {@code addRow}), and the hidden vector, scaled by {@code alpha}, is added into that row of
     * {@code wo} (via {@code addVectorToRow}). NOTE(review): the exact semantics of
     * {@code addRow}/{@code addVectorToRow} are defined in the blas package; confirm there.
     *
     * @param targets     candidate target labels
     * @param targetIndex index into {@code targets} selecting the label to score
     * @param state       model state; its output vector is recomputed here
     * @param lr          learning rate scaling the gradient step
     * @param backprop    whether to update the gradient and {@code wo}
     * @return {@code -Loss.Companion.log(p[target])}
     */
    @Override
    public float forward(@NotNull IntArrayList targets, int targetIndex, @NotNull Model.State state,
                         float lr, boolean backprop) {
        Intrinsics.checkParameterIsNotNull(targets, "targets");
        Intrinsics.checkParameterIsNotNull(state, "state");
        computeOutput(state);

        final int target = targets.get(targetIndex);
        final DenseVector probs = state.getOutput();

        if (backprop) {
            final Matrix wo = getWo();
            final int rows = wo.getRow();
            for (int row = 0; row < rows; row++) {
                final float label = row == target ? 1.0f : 0.0f;
                final float alpha = lr * (label - probs.get(row));
                state.getGrad().addRow(wo, row, alpha);
                wo.addVectorToRow(state.getHidden(), row, alpha);
            }
        }
        return -Loss.Companion.log(probs.get(target));
    }
}
