package com.mayabot.nlp.fasttext.train;

import com.mayabot.nlp.fasttext.FastText;
import com.mayabot.nlp.fasttext.Model;
import com.mayabot.nlp.fasttext.args.Args;
import com.mayabot.nlp.fasttext.dictionary.Dictionary;
import com.mayabot.nlp.fasttext.loss.LossName;
import com.mayabot.nlp.fasttext.utils.IOUtilsKt;
import com.mayabot.nlp.fasttext.utils.IntArrayList;
import com.mayabot.nlp.fasttext.utils.LogUtilsKt;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.StringCompanionObject;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: FastTextTrain.kt */
@Metadata(mv = {1, 1, 16}, bv = {1, IOUtilsKt.byteZero, 3}, k = 1, d1 = {"��f\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n��\n\u0002\u0010\t\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n��\n\u0002\u0010\u0002\n��\n\u0002\u0010\u0007\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0010\u001c\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2\u00020\u0001:\u0002./B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J\b\u0010\"\u001a\u00020#H\u0002J \u0010$\u001a\u00020%2\u0006\u0010&\u001a\u00020'2\u0006\u0010\u0010\u001a\u00020\u00112\u0006\u0010(\u001a\u00020#H\u0002J\b\u0010&\u001a\u00020'H\u0002J\u001a\u0010)\u001a\u00020%2\u0012\u0010*\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020-0,0+R\u0011\u0010\u0007\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\b\u0010\tR\u0011\u0010\n\u001a\u00020\u000b¢\u0006\b\n��\u001a\u0004\b\f\u0010\rR\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u000e\u0010\u000fR\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0012\u001a\u00020\u0013¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\u0015R\u000e\u0010\u0016\u001a\u00020\u0013X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0017\u001a\u00020\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0019\u0010\tR\"\u0010\u001a\u001a\n\u0018\u00010\u001bj\u0004\u0018\u0001`\u001cX\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u001d\u0010\u001e\"\u0004\b\u001f\u0010 R\u000e\u0010!\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��¨\u00060"}, d2 = {"Lcom/mayabot/nlp/fasttext/train/FastTextTrain;", "", "trainArgs", "Lcom/mayabot/nlp/fasttext/args/Args;", "fastText", "Lcom/mayabot/nlp/fasttext/FastText;", 
"(Lcom/mayabot/nlp/fasttext/args/Args;Lcom/mayabot/nlp/fasttext/FastText;)V", "args", "getArgs", "()Lcom/mayabot/nlp/fasttext/args/Args;", "dict", "Lcom/mayabot/nlp/fasttext/dictionary/Dictionary;", "getDict", "()Lcom/mayabot/nlp/fasttext/dictionary/Dictionary;", "getFastText", "()Lcom/mayabot/nlp/fasttext/FastText;", "loss", "Lcom/mayabot/nlp/fasttext/train/FastTextTrain$ShareDouble;", "ntokens", "", "getNtokens", "()J", "startTime", "tokenCount", "Ljava/util/concurrent/atomic/AtomicLong;", "getTrainArgs", "trainException", "Ljava/lang/Exception;", "Lkotlin/Exception;", "getTrainException", "()Ljava/lang/Exception;", "setTrainException", "(Ljava/lang/Exception;)V", "wantProcessTotalTokens", "keepTraining", "", "printInfo", "", "progress", "", "stop", "startThreads", "sources", "", "", "Lcom/mayabot/nlp/fasttext/train/SampleLine;", "ShareDouble", "TrainThread", "fastText4j"})
/* loaded from: input_file:com/mayabot/nlp/fasttext/train/FastTextTrain.class */
/**
 * Multi-threaded training driver: starts one {@link TrainThread} per sample source, polls the shared
 * token counter to print progress, and rethrows any failure recorded through
 * {@link #setTrainException} once all workers have been joined.
 */
public final class FastTextTrain {
    /** Tokens processed so far, shared by all worker threads. */
    private final AtomicLong tokenCount;
    /** Latest loss published by worker thread 0; the initial -1 means "nothing reported yet". */
    private final ShareDouble loss;
    /** Construction time in ms (System.currentTimeMillis); basis for rate, ETA and total-time output. */
    private long startTime;

    /**
     * First failure recorded via {@link #setTrainException} (presumably by a worker), or null.
     * Volatile because it is polled from other threads through {@link #keepTraining()}.
     */
    @Nullable
    private volatile Exception trainException;

    @NotNull
    private final Dictionary dict;

    @NotNull
    private final Args args;
    private final long ntokens;
    /** Training budget: {@code args.epoch * ntokens}. */
    private final long wantProcessTotalTokens;

    @NotNull
    private final Args trainArgs;

    @NotNull
    private final FastText fastText;

    /* compiled from: FastTextTrain.kt */
    @Metadata(mv = {1, 1, 16}, bv = {1, IOUtilsKt.byteZero, 3}, k = 1, d1 = {"�� \n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\u0006\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u0007\n��\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u000e\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\u0003J\u0006\u0010\u000b\u001a\u00020\fR\u001a\u0010\u0002\u001a\u00020\u0003X\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0005\u0010\u0006\"\u0004\b\u0007\u0010\u0004¨\u0006\r"}, d2 = {"Lcom/mayabot/nlp/fasttext/train/FastTextTrain$ShareDouble;", "", "value", "", "(D)V", "getValue", "()D", "setValue", "set", "", "v", "toFloat", "", "fastText4j"})
    /* loaded from: input_file:com/mayabot/nlp/fasttext/train/FastTextTrain$ShareDouble.class */
    /** Mutable double cell shared between a worker thread (writer) and the reporting thread (reader). */
    public static final class ShareDouble {
        /**
         * Volatile so the reporting thread sees updates, and so 64-bit reads/writes are atomic
         * (plain double accesses may tear per JLS 17.7).
         */
        private volatile double value;

        /** Returns the current value narrowed to float. */
        public final float toFloat() {
            return (float) this.value;
        }

        /** Replaces the current value. */
        public final void set(double d) {
            this.value = d;
        }

        public final double getValue() {
            return this.value;
        }

        public final void setValue(double d) {
            this.value = d;
        }

        public ShareDouble(double d) {
            this.value = d;
        }
    }

    /* compiled from: FastTextTrain.kt */
    @Metadata(mv = {1, 1, 16}, bv = {1, IOUtilsKt.byteZero, 3}, k = 1, d1 = {"��F\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0010\u001c\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\t\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n\u0002\b\u0005\b\u0080\u0004\u0018��2\u00020\u0001B\u001b\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005¢\u0006\u0002\u0010\u0007J(\u0010\u0015\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001cH\u0002J\b\u0010\u001d\u001a\u00020\u0016H\u0016J(\u0010\u001e\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001cH\u0002J0\u0010\u001f\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001c2\u0006\u0010 \u001a\u00020\u001cH\u0002R\u001a\u0010\b\u001a\u00020\u0003X\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\t\u0010\n\"\u0004\b\u000b\u0010\fR\u0011\u0010\r\u001a\u00020\u000e¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005X\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0011\u001a\u00020\u0012¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��¨\u0006!"}, d2 = {"Lcom/mayabot/nlp/fasttext/train/FastTextTrain$TrainThread;", "Ljava/lang/Runnable;", "threadId", "", "parts", "", "Lcom/mayabot/nlp/fasttext/train/SampleLine;", "(Lcom/mayabot/nlp/fasttext/train/FastTextTrain;ILjava/lang/Iterable;)V", "localTokenCount", "getLocalTokenCount", "()I", "setLocalTokenCount", "(I)V", "ntokens", "", 
"getNtokens", "()J", "state", "Lcom/mayabot/nlp/fasttext/Model$State;", "getState", "()Lcom/mayabot/nlp/fasttext/Model$State;", "cbow", "", "model", "Lcom/mayabot/nlp/fasttext/Model;", "lr", "", "line", "Lcom/mayabot/nlp/fasttext/utils/IntArrayList;", "run", "skipgram", "supervised", "labels", "fastText4j"})
    /* loaded from: input_file:com/mayabot/nlp/fasttext/train/FastTextTrain$TrainThread.class */
    /**
     * One training worker bound to a single sample source.
     *
     * <p>Static nested class with an explicit {@link #owner}: the decompiled form declared a field
     * named {@code this$0} in an inner class, which collides with the javac-synthesized outer
     * reference. The constructor {@code (FastTextTrain, int, Iterable)} matches the Kotlin binary
     * signature recorded in the metadata above.
     */
    public static final class TrainThread implements Runnable {

        @NotNull
        private final Model.State state;
        private final long ntokens;
        private int localTokenCount;
        private final int threadId;
        private final Iterable<SampleLine> parts;
        /** Trainer that owns this worker; supplies args, dictionary and the shared counters. */
        private final FastTextTrain owner;

        @NotNull
        public final Model.State getState() {
            return this.state;
        }

        public final long getNtokens() {
            return this.ntokens;
        }

        public final int getLocalTokenCount() {
            return this.localTokenCount;
        }

        public final void setLocalTokenCount(int i) {
            this.localTokenCount = i;
        }

        /*
         * NOTE(review): this method body could not be recovered by the decompiler (434 bytecode
         * instructions were skipped), so it always throws and training cannot run until it is
         * restored from FastTextTrain.kt. The fragments the decompiler did recover show this
         * step inside its loop:
         *
         *     if (localTokenCount > owner.getArgs().getLrUpdateRate()) {
         *         owner.tokenCount.addAndGet(localTokenCount);
         *         localTokenCount = 0;
         *         if (threadId == 0) {
         *             owner.loss.set(state.getLoss());
         *         }
         *     }
         */
        @Override // java.lang.Runnable
        public void run() {
            throw new UnsupportedOperationException("Method not decompiled: com.mayabot.nlp.fasttext.train.FastTextTrain.TrainThread.run():void");
        }

        /**
         * Supervised step: updates {@code model} with {@code line} as input and {@code labels} as
         * targets. Skipped when either list is empty.
         *
         * <p>With the {@code ova} loss every label is a target; otherwise a single label index is
         * drawn from the thread's RNG.
         */
        private final void supervised(Model.State state, Model model, float lr, IntArrayList line, IntArrayList labels) {
            if (labels.size() == 0 || line.size() == 0) {
                return;
            }
            if (this.owner.getArgs().getLoss() == LossName.ova) {
                model.update(line, labels, Model.Companion.getKAllLabelsAsTarget(), lr, state);
            } else {
                model.update(line, labels, state.getRng().nextInt(labels.size()), lr, state);
            }
        }

        /**
         * CBOW step. For each position {@code w}, draws a window radius in {@code [1, ws]}. It
         * gathers the subword ids of every other in-range word within that radius, then calls
         * {@code model.update} with target index {@code w}.
         */
        private final void cbow(Model.State state, Model model, float lr, IntArrayList line) {
            // Kotlin default-arguments constructor call (both constructor parameters defaulted).
            IntArrayList bow = new IntArrayList(0, null, 3, null);
            int size = line.size();
            for (int w = 0; w < size; w++) {
                int boundary = state.getRng().nextInt(this.owner.getArgs().getWs()) + 1;
                bow.clear();
                // Bounded loop: the decompiled while(true) had lost its break and never terminated.
                for (int c = -boundary; c <= boundary; c++) {
                    int pos = w + c;
                    if (c != 0 && pos >= 0 && pos < size) {
                        bow.addAll(this.owner.getDict().getSubwords(line.get(pos)));
                    }
                }
                model.update(bow, line, w, lr, state);
            }
        }

        /**
         * Skip-gram step. For each position {@code w}, draws a window radius in {@code [1, ws]}.
         * It then calls {@code model.update} with the subwords of word {@code w} as input, once for
         * every other in-range position within that radius.
         */
        private final void skipgram(Model.State state, Model model, float lr, IntArrayList line) {
            int size = line.size();
            for (int w = 0; w < size; w++) {
                int boundary = state.getRng().nextInt(this.owner.getArgs().getWs()) + 1;
                IntArrayList subwords = this.owner.getDict().getSubwords(line.get(w));
                // Bounded loop: the decompiled while(true) had lost its break and never terminated.
                for (int c = -boundary; c <= boundary; c++) {
                    int target = w + c;
                    if (c != 0 && target >= 0 && target < size) {
                        model.update(subwords, line, target, lr, state);
                    }
                }
            }
        }

        /**
         * @param fastTextTrain owning trainer
         * @param i worker index; 0 is the worker that publishes the loss
         * @param iterable samples processed by this worker
         */
        public TrainThread(FastTextTrain fastTextTrain, int i, @NotNull Iterable<SampleLine> iterable) {
            Intrinsics.checkParameterIsNotNull(iterable, "parts");
            this.owner = fastTextTrain;
            this.threadId = i;
            this.parts = iterable;
            // Offset the seed by the worker index, as upstream fastText does. Otherwise all workers
            // draw identical random streams (window radii, label sampling). Worker 0 is unaffected.
            this.state = new Model.State(fastTextTrain.getArgs().getDim(), fastTextTrain.getFastText().getOutput().getRow(), i + fastTextTrain.getTrainArgs().getSeed());
            this.ntokens = fastTextTrain.getDict().getNtokens();
        }
    }

    @Nullable
    public final Exception getTrainException() {
        return this.trainException;
    }

    public final void setTrainException(@Nullable Exception exc) {
        this.trainException = exc;
    }

    @NotNull
    public final Dictionary getDict() {
        return this.dict;
    }

    @NotNull
    public final Args getArgs() {
        return this.args;
    }

    public final long getNtokens() {
        return this.ntokens;
    }

    /**
     * Returns true while the shared token count is below {@code epoch * ntokens} and no exception
     * has been recorded.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public final boolean keepTraining() {
        return this.tokenCount.longValue() < this.wantProcessTotalTokens && this.trainException == null;
    }

    /** Returns the fraction of the token budget processed so far. */
    /* JADX INFO: Access modifiers changed from: private */
    public final float progress() {
        return this.tokenCount.floatValue() / ((float) this.wantProcessTotalTokens);
    }

    /**
     * Starts one {@link TrainThread} per element of {@code list} and prints progress every 100 ms
     * while training continues. It then joins all workers and logs the final statistics.
     *
     * <p>Exceptions are propagated unchecked, as in the Kotlin source, and the signature is unchanged:
     * <ul>
     *   <li>any exception recorded via {@link #setTrainException} is rethrown after the join;</li>
     *   <li>an {@link InterruptedException} is rethrown if the caller is interrupted while waiting.</li>
     * </ul>
     *
     * @param list one sample source per worker; its size is the number of threads started
     */
    public final void startThreads(@NotNull List<? extends Iterable<SampleLine>> list) {
        Intrinsics.checkParameterIsNotNull(list, "sources");
        int threadCount = list.size();
        List<Thread> threads = new ArrayList<Thread>(threadCount);
        for (int i = 0; i < threadCount; i++) {
            threads.add(new Thread(new TrainThread(this, i, list.get(i))));
        }
        for (Thread thread : threads) {
            thread.start();
        }
        try {
            // Also stop once every worker has terminated: nothing is left to advance tokenCount
            // (e.g. an empty source list), and polling would otherwise never end.
            while (keepTraining() && anyAlive(threads)) {
                Thread.sleep(100L);
                // loss stays at its initial -1 until a worker publishes a value.
                if (this.loss.toFloat() >= 0) {
                    float progress = progress();
                    LogUtilsKt.logger("\r");
                    printInfo(progress, this.loss, false);
                }
            }
            for (Thread thread : threads) {
                thread.join();
            }
        } catch (InterruptedException e) {
            throw FastTextTrain.<RuntimeException>sneakyThrow(e);
        }
        Exception exc = this.trainException;
        if (exc != null) {
            throw FastTextTrain.<RuntimeException>sneakyThrow(exc);
        }
        LogUtilsKt.logger("\r");
        printInfo(1.0f, this.loss, true);
        LogUtilsKt.loggerln();
        LogUtilsKt.loggerln("Train use time " + (System.currentTimeMillis() - this.startTime) + " ms");
    }

    /** Returns true if at least one of {@code threads} is still running. */
    private static boolean anyAlive(List<Thread> threads) {
        for (Thread thread : threads) {
            if (thread.isAlive()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rethrows {@code t} unchanged without declaring it. This reproduces the unchecked exception
     * propagation of the Kotlin source, so callers catching the original type still work.
     *
     * @return never returns normally; the return type lets call sites write {@code throw sneakyThrow(t)}
     */
    @SuppressWarnings("unchecked")
    private static <E extends Throwable> RuntimeException sneakyThrow(Throwable t) throws E {
        throw (E) t;
    }

    /**
     * Logs one progress line: percent done and words/sec/thread. Unless {@code finished}, it also
     * logs the linearly decayed learning rate and an ETA. The average loss is always logged.
     *
     * @param progress fraction of the token budget processed (1.0 for the final line)
     * @param currentLoss loss holder to report
     * @param finished true for the final summary line
     */
    private final void printInfo(float progress, ShareDouble currentLoss, boolean finished) {
        // Divide by 1000.0: long division truncated elapsed time to whole seconds, so during the
        // first second words/sec/thread came out as Infinity/NaN (division by 0).
        double elapsedSeconds = (System.currentTimeMillis() - this.startTime) / 1000.0d;
        double lr = this.trainArgs.getLr() * (1.0d - progress);
        float percent = progress * 100.0f;
        double wordsPerSecPerThread = 0.0d;
        long eta = 2592000L; // 30 days in seconds: placeholder until a rate can be estimated
        if (progress > 0 && elapsedSeconds > 0) {
            eta = (long) ((elapsedSeconds * (100.0f - percent)) / percent);
            wordsPerSecPerThread = (this.tokenCount.doubleValue() / elapsedSeconds) / this.trainArgs.getThread();
        }
        long etaHours = eta / 3600;
        long etaMinutes = (eta % 3600) / 60;
        long etaSeconds = (eta % 3600) % 60;

        StringBuilder sb = new StringBuilder();
        sb.append("Progress: ")
                .append(String.format("%2.2f", percent))
                .append("% words/sec/thread: ")
                .append(String.format("%8.0f", wordsPerSecPerThread));
        if (!finished) {
            sb.append(String.format(" lr: %2.5f", lr));
        }
        sb.append(String.format(" avg.loss: %2.5f", currentLoss.toFloat()));
        if (!finished) {
            sb.append(" ETA: " + etaHours + "h " + etaMinutes + "m " + etaSeconds + "s");
        }
        LogUtilsKt.logger(sb);
    }

    @NotNull
    public final Args getTrainArgs() {
        return this.trainArgs;
    }

    @NotNull
    public final FastText getFastText() {
        return this.fastText;
    }

    /**
     * @param args training arguments; {@code epoch} times the dictionary's token count is the budget
     * @param fastText model whose dictionary and output matrix are trained
     */
    public FastTextTrain(@NotNull Args args, @NotNull FastText fastText) {
        Intrinsics.checkParameterIsNotNull(args, "trainArgs");
        Intrinsics.checkParameterIsNotNull(fastText, "fastText");
        this.trainArgs = args;
        this.fastText = fastText;
        this.tokenCount = new AtomicLong(0L);
        this.loss = new ShareDouble(-1.0d);
        this.startTime = System.currentTimeMillis();
        this.dict = this.fastText.getDict();
        this.args = this.trainArgs;
        this.ntokens = this.dict.getNtokens();
        this.wantProcessTotalTokens = this.args.getEpoch() * this.ntokens;
    }
}
