package com.mayabot.nlp.fasttext;

import com.mayabot.nlp.blas.DenseVector;
import com.mayabot.nlp.blas.Matrix;
import com.mayabot.nlp.common.IntArrayList;
import com.mayabot.nlp.fasttext.loss.Loss;
import java.util.List;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.random.Random;
import kotlin.random.RandomKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: Model.kt */
@Metadata(mv = {1, 4, 1}, bv = {1, 0, 3}, k = 1, d1 = {"��P\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\t\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n��\n\u0002\u0010\u0007\n��\n\u0002\u0010!\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\u0018�� $2\u00020\u0001:\u0002$%B%\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\u0006\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJ\u0018\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0015\u001a\u00020\u0016H\u0002J8\u0010\u0017\u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0018\u001a\u00020\u00192\u0006\u0010\u001a\u001a\u00020\u001b2\u0010\u0010\u001c\u001a\f\u0012\u0004\u0012\u00020\u001e0\u001dj\u0002`\u001f2\u0006\u0010\u0015\u001a\u00020\u0016J.\u0010 \u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010!\u001a\u00020\u00142\u0006\u0010\"\u001a\u00020\u00192\u0006\u0010#\u001a\u00020\u001b2\u0006\u0010\u0015\u001a\u00020\u0016R\u0011\u0010\u0005\u001a\u00020\u0006¢\u0006\b\n��\u001a\u0004\b\n\u0010\u000bR\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\f\u0010\rR\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u000e\u0010\u000fR\u0011\u0010\u0004\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u000f¨\u0006&"}, d2 = {"Lcom/mayabot/nlp/fasttext/Model;", "", "wi", "Lcom/mayabot/nlp/blas/Matrix;", "wo", "loss", "Lcom/mayabot/nlp/fasttext/loss/Loss;", "normalizeGradient", "", "(Lcom/mayabot/nlp/blas/Matrix;Lcom/mayabot/nlp/blas/Matrix;Lcom/mayabot/nlp/fasttext/loss/Loss;Z)V", "getLoss", "()Lcom/mayabot/nlp/fasttext/loss/Loss;", "getNormalizeGradient", "()Z", "getWi", "()Lcom/mayabot/nlp/blas/Matrix;", "getWo", "computeHidden", "", "input", 
"Lcom/mayabot/nlp/common/IntArrayList;", "state", "Lcom/mayabot/nlp/fasttext/Model$State;", "predict", "k", "", "threshold", "", "heap", "", "Lcom/mayabot/nlp/fasttext/ScoreIdPair;", "Lcom/mayabot/nlp/fasttext/Predictions;", "update", "targets", "targetIndex", "lr", "Companion", "State", "mynlp"})
/* loaded from: input_file:com/mayabot/nlp/fasttext/Model.class */
public final class Model {

    /** Input matrix; the rows selected by an input id list are averaged into {@code State.hidden}. */
    @NotNull
    private final Matrix wi;

    /** Output matrix; its row count is the prediction limit used when {@code k == kUnlimitedPredictions}. */
    @NotNull
    private final Matrix wo;

    /** Loss strategy used for both prediction ({@link #predict}) and training ({@link #update}). */
    @NotNull
    private final Loss loss;

    /** When true, {@link #update} scales the gradient by {@code 1 / input.size()} before applying it to {@code wi}. */
    private final boolean normalizeGradient;

    @NotNull
    public static final Companion Companion = new Companion(null);

    /** Sentinel value for {@code k} in {@link #predict}: allow one prediction per row of {@code wo}. */
    private static final int kUnlimitedPredictions = -1;

    /** Sentinel target index exposed through the companion; presumably interpreted by {@link Loss} implementations — confirm. */
    private static final int kAllLabelsAsTarget = -1;

    /* compiled from: Model.kt */
    /** Kotlin companion object exposing the class-level sentinel constants. */
    @Metadata(mv = {1, 4, 1}, bv = {1, 0, 3}, k = 1, d1 = {"��\u0014\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0005\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002R\u0014\u0010\u0003\u001a\u00020\u0004X\u0086D¢\u0006\b\n��\u001a\u0004\b\u0005\u0010\u0006R\u0014\u0010\u0007\u001a\u00020\u0004X\u0086D¢\u0006\b\n��\u001a\u0004\b\b\u0010\u0006¨\u0006\t"}, d2 = {"Lcom/mayabot/nlp/fasttext/Model$Companion;", "", "()V", "kAllLabelsAsTarget", "", "getKAllLabelsAsTarget", "()I", "kUnlimitedPredictions", "getKUnlimitedPredictions", "mynlp"})
    /* loaded from: input_file:com/mayabot/nlp/fasttext/Model$Companion.class */
    public static final class Companion {

        /** Returns the "no limit" sentinel accepted as {@code k} by {@link Model#predict}. */
        public final int getKUnlimitedPredictions() {
            return Model.kUnlimitedPredictions;
        }

        /** Returns the "all labels as target" sentinel target index. */
        public final int getKAllLabelsAsTarget() {
            return Model.kAllLabelsAsTarget;
        }

        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /* compiled from: Model.kt */
    /**
     * Per-caller scratch state: hidden/output/gradient buffers, a seeded random source, and a
     * running loss accumulator.
     */
    @Metadata(mv = {1, 4, 1}, bv = {1, 0, 3}, k = 1, d1 = {"��0\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\b\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u0007\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n��\u0018��2\u00020\u0001B\u001d\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0003¢\u0006\u0002\u0010\u0006J\u000e\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\r\u001a\u00020\u000eR\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\t\u0010\nR\u0011\u0010\u000b\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\f\u0010\nR\u0011\u0010\r\u001a\u00020\u000e8F¢\u0006\u0006\u001a\u0004\b\u000f\u0010\u0010R\u000e\u0010\u0011\u001a\u00020\u000eX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u0011\u0010\u0013\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\nR\u0011\u0010\u0015\u001a\u00020\u0016¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018¨\u0006\u001b"}, d2 = {"Lcom/mayabot/nlp/fasttext/Model$State;", "", "hiddenSize", "", "outputSize", "seed", "(III)V", "grad", "Lcom/mayabot/nlp/blas/DenseVector;", "getGrad", "()Lcom/mayabot/nlp/blas/DenseVector;", "hidden", "getHidden", "loss", "", "getLoss", "()F", "lossValue", "nexamples", "output", "getOutput", "rng", "Lkotlin/random/Random;", "getRng", "()Lkotlin/random/Random;", "incrementNExamples", "", "mynlp"})
    /* loaded from: input_file:com/mayabot/nlp/fasttext/Model$State.class */
    public static final class State {
        /** Sum of all loss values passed to {@link #incrementNExamples(float)}. */
        private float lossValue;
        /** Number of calls to {@link #incrementNExamples(float)}. */
        private int nexamples;

        @NotNull
        private final DenseVector hidden;

        @NotNull
        private final DenseVector output;

        @NotNull
        private final DenseVector grad;

        @NotNull
        private final Random rng;

        /** Hidden vector of length {@code hiddenSize}, rewritten by {@code Model.computeHidden}. */
        @NotNull
        public final DenseVector getHidden() {
            return this.hidden;
        }

        /** Output buffer of length {@code outputSize}; presumably filled by the {@link Loss} — confirm. */
        @NotNull
        public final DenseVector getOutput() {
            return this.output;
        }

        /** Gradient buffer of length {@code hiddenSize}; zeroed at the start of every {@link Model#update}. */
        @NotNull
        public final DenseVector getGrad() {
            return this.grad;
        }

        /** Random source seeded with the {@code seed} constructor argument. */
        @NotNull
        public final Random getRng() {
            return this.rng;
        }

        /**
         * Returns the mean of all recorded loss values.
         *
         * <p>Before any example has been recorded this is {@code 0f / 0}, i.e. {@code NaN}.
         */
        public final float getLoss() {
            return this.lossValue / this.nexamples;
        }

        /**
         * Records one example's loss value.
         *
         * @param f the loss value to add to the running sum
         */
        public final void incrementNExamples(float f) {
            this.lossValue += f;
            this.nexamples++;
        }

        /**
         * @param i  hidden size (length of {@code hidden} and {@code grad})
         * @param i2 output size (length of {@code output})
         * @param i3 seed for {@code rng}
         */
        public State(int i, int i2, int i3) {
            this.hidden = new DenseVector(i);
            this.output = new DenseVector(i2);
            this.grad = new DenseVector(i);
            this.rng = RandomKt.Random(i3);
        }
    }

    /**
     * Sets {@code state.hidden} to the average of the {@code wi} rows whose indices are listed in
     * {@code intArrayList}.
     *
     * <p>For an empty input the hidden vector is left zeroed. Scaling by {@code 1.0f / 0} would
     * multiply by Infinity and (assuming element-wise scaling) turn every component into NaN.
     */
    private final void computeHidden(IntArrayList intArrayList, State state) {
        DenseVector hidden = state.getHidden();
        hidden.zero();
        int size = intArrayList.size();
        if (size == 0) {
            return;
        }
        int[] buffer = intArrayList.getBuffer();
        for (int i = 0; i < size; i++) {
            Matrix.DefaultImpls.addRowToVector$default(this.wi, hidden, buffer[i], null, 4, null);
        }
        hidden.timesAssign(Float.valueOf(1.0f / size));
    }

    /**
     * Computes the hidden vector for {@code input} and asks the loss for the best predictions.
     *
     * @param intArrayList input ids (rows of {@code wi})
     * @param i            maximum number of predictions ({@code k}); {@code kUnlimitedPredictions}
     *                     (-1) means "one per row of {@code wo}"
     * @param f            score threshold forwarded to {@link Loss}
     * @param list         result heap, filled by {@link Loss}
     * @param state        scratch state whose hidden vector is overwritten
     * @throws IllegalArgumentException if the resolved {@code k} is less than 1
     */
    public final void predict(@NotNull IntArrayList intArrayList, int i, float f, @NotNull List<ScoreIdPair> list, @NotNull State state) {
        Intrinsics.checkNotNullParameter(intArrayList, "input");
        Intrinsics.checkNotNullParameter(list, "heap");
        Intrinsics.checkNotNullParameter(state, "state");
        // Resolve the "unlimited" sentinel to a concrete limit and pass THAT to the loss. Forwarding
        // the raw -1 would hand a nonsensical "top -1" request to Loss.predict.
        int k = i == kUnlimitedPredictions ? this.wo.getRow() : i;
        if (k <= 0) {
            throw new IllegalArgumentException("k needs to be 1 or higher");
        }
        computeHidden(intArrayList, state);
        this.loss.predict(k, f, list, state);
    }

    /**
     * Performs one training step for a single example.
     *
     * <p>Does nothing for an empty input. The steps are:
     * <ul>
     *   <li>recompute the hidden vector;</li>
     *   <li>zero {@code state.grad};</li>
     *   <li>record the value returned by {@code Loss.forward} via {@link State#incrementNExamples(float)};</li>
     *   <li>optionally scale the gradient by {@code 1 / input.size()};</li>
     *   <li>add the gradient to every {@code wi} row listed in the input.</li>
     * </ul>
     * {@code Loss.forward} is presumably what fills {@code state.grad}; the trailing {@code true}
     * looks like a backprop flag — confirm against {@link Loss}.
     *
     * @param intArrayList  input ids (rows of {@code wi})
     * @param intArrayList2 target label ids
     * @param i             target index forwarded to the loss
     * @param f             learning rate forwarded to the loss
     * @param state         scratch state (hidden, grad and loss accumulator are updated)
     */
    public final void update(@NotNull IntArrayList intArrayList, @NotNull IntArrayList intArrayList2, int i, float f, @NotNull State state) {
        Intrinsics.checkNotNullParameter(intArrayList, "input");
        Intrinsics.checkNotNullParameter(intArrayList2, "targets");
        Intrinsics.checkNotNullParameter(state, "state");
        int size = intArrayList.size();
        if (size == 0) {
            return;
        }
        computeHidden(intArrayList, state);
        DenseVector grad = state.getGrad();
        grad.zero();
        state.incrementNExamples(this.loss.forward(intArrayList2, i, state, f, true));
        if (this.normalizeGradient) {
            grad.timesAssign(Float.valueOf(1.0f / size));
        }
        int[] buffer = intArrayList.getBuffer();
        for (int j = 0; j < size; j++) {
            this.wi.addVectorToRow(grad, buffer[j], 1.0f);
        }
    }

    @NotNull
    public final Matrix getWi() {
        return this.wi;
    }

    @NotNull
    public final Matrix getWo() {
        return this.wo;
    }

    @NotNull
    public final Loss getLoss() {
        return this.loss;
    }

    public final boolean getNormalizeGradient() {
        return this.normalizeGradient;
    }

    /**
     * @param matrix  input matrix {@code wi}
     * @param matrix2 output matrix {@code wo}
     * @param loss    loss strategy
     * @param z       whether {@link #update} normalises the gradient by the input size
     */
    public Model(@NotNull Matrix matrix, @NotNull Matrix matrix2, @NotNull Loss loss, boolean z) {
        Intrinsics.checkNotNullParameter(matrix, "wi");
        Intrinsics.checkNotNullParameter(matrix2, "wo");
        Intrinsics.checkNotNullParameter(loss, "loss");
        this.wi = matrix;
        this.wo = matrix2;
        this.loss = loss;
        this.normalizeGradient = z;
    }
}
