package com.alibaba.alink.operator.batch.huge.line;

import com.alibaba.alink.common.utils.ExpTableArray;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.flink.api.java.tuple.Tuple2;

/* loaded from: input_file:com/alibaba/alink/operator/batch/huge/line/LinePullAndTrainOperation.class */
/**
 * Edge-sampling pull/train routines for the LINE embedding operator.
 *
 * <p>{@link #getIndexes} replays the same sampling sequence as {@link #train} (same seed, same alias
 * table, same negative-sampling stream) to collect every vertex id that training will touch, so the
 * corresponding vectors can be pulled beforehand. Both methods must therefore parse the edge list
 * identically.
 */
public class LinePullAndTrainOperation {
    /** Number of negative samples drawn per sampled positive edge. */
    private final int negaTime;
    /**
     * Pool used by {@link #negativeSampling}; ids are interpolated between adjacent entries.
     * NOTE(review): presumably sorted ascending vertex ids — confirm with the caller.
     */
    private final Long[] nsPool;

    /**
     * Alias-method sampler (Walker/Vose) over a fixed set of non-negative weights; O(1) per draw.
     */
    public static class AliasSampling {
        private final int[] alias;
        private final double[] prob;
        private final Random rand;
        private final int weightSize;

        /**
         * @param dArr weights, one per outcome; they are normalized internally
         * @param i    seed for this sampler's private {@link Random}
         */
        public AliasSampling(double[] dArr, int i) {
            Tuple2<int[], double[]> table = initAlias(dArr);
            this.alias = table.f0;
            this.prob = table.f1;
            this.rand = new Random(i);
            this.weightSize = dArr.length;
        }

        /**
         * Builds the alias table for the given weights.
         *
         * @param weights outcome weights
         * @return {@code (alias, prob)}: column {@code k} keeps outcome {@code k} with probability
         *     {@code prob[k]} and otherwise yields {@code alias[k]}
         */
        static Tuple2<int[], double[]> initAlias(double[] weights) {
            int n = weights.length;
            int[] alias = new int[n];
            double[] prob = new double[n];
            double[] normProb = new double[n];
            int[] large = new int[n];
            int[] small = new int[n];

            double sum = 0.0;
            for (double w : weights) {
                sum += w;
            }
            // Scale so the mean column mass is exactly 1.
            for (int k = 0; k < n; k++) {
                normProb[k] = (weights[k] * n) / sum;
            }

            // Two stacks: entries in [0, numSmall) / [0, numLarge) are live.
            int numSmall = 0;
            int numLarge = 0;
            for (int k = n - 1; k >= 0; k--) {
                if (normProb[k] < 1.0) {
                    small[numSmall++] = k;
                } else {
                    large[numLarge++] = k;
                }
            }

            while (numSmall > 0 && numLarge > 0) {
                int s = small[--numSmall];
                int l = large[--numLarge];
                prob[s] = normProb[s];
                alias[s] = l;
                // The large outcome donates (1 - normProb[s]) to fill column s.
                normProb[l] = (normProb[l] + normProb[s]) - 1.0;
                // Push with post-increment so l becomes the new top of stack; a pre-increment
                // would leave the just-popped entry on top and skip a slot.
                if (normProb[l] < 1.0) {
                    small[numSmall++] = l;
                } else {
                    large[numLarge++] = l;
                }
            }

            // Whatever remains has mass 1 up to floating-point error.
            while (numLarge > 0) {
                prob[large[--numLarge]] = 1.0;
            }
            while (numSmall > 0) {
                prob[small[--numSmall]] = 1.0;
            }
            return Tuple2.of(alias, prob);
        }

        /** Draws one outcome index in {@code [0, weightSize)}. Requires at least one weight. */
        public int sampling() {
            int column = (int) Math.floor(this.weightSize * this.rand.nextDouble());
            return this.rand.nextDouble() < this.prob[column] ? column : this.alias[column];
        }
    }

    /**
     * @param negaTime number of negative samples per positive edge
     * @param nsPool   pool consumed by {@link #negativeSampling}
     */
    public LinePullAndTrainOperation(int negaTime, Long[] nsPool) {
        this.negaTime = negaTime;
        this.nsPool = nsPool;
    }

    /**
     * Draws one negative id by picking a uniform position in {@code [0, pool.length - 1)} and
     * linearly interpolating (rounded) between the two neighbouring pool entries.
     *
     * @param random source of randomness; exactly one {@code nextDouble()} is consumed when the
     *     pool has at least two entries
     * @param pool   non-empty pool of ids
     * @return the sampled id
     * @throws IllegalArgumentException if {@code pool} is empty
     */
    public static long negativeSampling(Random random, Long[] pool) {
        if (pool.length == 0) {
            throw new IllegalArgumentException("Negative sampling pool is empty.");
        }
        if (pool.length == 1) {
            // Nothing to interpolate; indexing pool[1] would be out of bounds.
            return pool[0];
        }
        double pos = random.nextDouble() * (pool.length - 1);
        int lo = (int) Math.floor(pos);
        long base = pool[lo];
        return Math.round((pool[lo + 1] - base) * (pos - lo)) + base;
    }

    /**
     * Collects every vertex id that {@link #train} would touch for the same arguments.
     *
     * @param seed        seed shared with {@link #train}
     * @param sampleRatio number of positive samples as a multiple of the edge count
     * @param edges       edges as {@code [srcId, dstId, weight]}
     * @param indexes     output set; sampled endpoint and negative ids are added to it
     */
    public void getIndexes(int seed, double sampleRatio, List<Number[]> edges, Set<Long> indexes) {
        Random random = new Random(seed);
        int numEdges = edges.size();
        long[][] endpoints = new long[numEdges][];
        double[] weights = new double[numEdges];
        readEdges(edges, endpoints, weights);

        AliasSampling edgeSampler = new AliasSampling(weights, seed);
        int numSamples = (int) Math.round(sampleRatio * numEdges);
        for (int step = 0; step < numSamples; step++) {
            long[] edge = endpoints[edgeSampler.sampling()];
            indexes.add(edge[0]);
            indexes.add(edge[1]);
            for (int k = 0; k < this.negaTime; k++) {
                indexes.add(negativeSampling(random, this.nsPool));
            }
        }
    }

    /**
     * Runs SGD over sampled edges with negative sampling, updating the vector tables in place.
     *
     * <p>Tables are row-major with {@code vecSize} floats per row; the row of an id is
     * {@code map.get(id)}.
     *
     * @param seed                 seed shared with {@link #getIndexes}
     * @param learningRate         initial learning rate
     * @param minLearningRateRatio lower bound of the linear learning-rate decay factor
     * @param shareVec             if true, targets and negatives also live in {@code vertexVec};
     *                             otherwise they live in {@code contextVec}
     * @param vecSize              vector dimension
     * @param sampleRatio          number of positive samples as a multiple of the edge count
     * @param vertexVec            source-vertex table (always holds the source vectors)
     * @param contextVec           target/negative table used when {@code shareVec} is false
     * @param map                  id to row index; every sampled id must be present
     * @param edges                edges as {@code [srcId, dstId, weight]}
     */
    public void train(int seed, double learningRate, double minLearningRateRatio, boolean shareVec,
                      int vecSize, double sampleRatio, float[] vertexVec, float[] contextVec,
                      Map<Long, Integer> map, List<Number[]> edges) {
        Random random = new Random(seed);
        int numEdges = edges.size();
        long[][] endpoints = new long[numEdges][];
        double[] weights = new double[numEdges];
        readEdges(edges, endpoints, weights);

        AliasSampling edgeSampler = new AliasSampling(weights, seed);
        float[] targetTable = shareVec ? vertexVec : contextVec;
        float[] srcGrad = new float[vecSize];
        int numSamples = (int) Math.round(sampleRatio * numEdges);
        // Loss accumulator required by update(); its value is not read here.
        double[] loss = {0.0};

        for (int step = 0; step < numSamples; step++) {
            double rate = learningRate
                * Math.max(1.0 - ((step * 1.0) / numSamples), minLearningRateRatio);
            Arrays.fill(srcGrad, 0.0f);

            long[] edge = endpoints[edgeSampler.sampling()];
            int srcRow = map.get(edge[0]);
            float[] srcVec = getVec(vertexVec, srcRow, vecSize);
            int dstRow = map.get(edge[1]);
            float[] dstVec = getVec(targetTable, dstRow, vecSize);
            update(srcVec, dstVec, srcGrad, 1, rate, loss);
            setVec(targetTable, dstVec, dstRow, vecSize);

            for (int k = 0; k < this.negaTime; k++) {
                int negRow = map.get(negativeSampling(random, this.nsPool));
                float[] negVec = getVec(targetTable, negRow, vecSize);
                update(srcVec, negVec, srcGrad, 0, rate, loss);
                setVec(targetTable, negVec, negRow, vecSize);
            }

            // The source is updated once per step with the gradient accumulated over all targets.
            floatAxpy(1.0f, srcGrad, srcVec);
            setVec(vertexVec, srcVec, srcRow, vecSize);
        }
    }

    /** Parses {@code [srcId, dstId, weight]} rows into parallel endpoint and weight arrays. */
    private static void readEdges(List<Number[]> edges, long[][] endpoints, double[] weights) {
        int k = 0;
        for (Number[] edge : edges) {
            endpoints[k] = new long[] {edge[0].longValue(), edge[1].longValue()};
            weights[k] = edge[2].doubleValue();
            k++;
        }
    }

    /** Copies row {@code row} (width {@code dim}) out of {@code table}. */
    private static float[] getVec(float[] table, int row, int dim) {
        float[] vec = new float[dim];
        System.arraycopy(table, row * dim, vec, 0, dim);
        return vec;
    }

    /** Writes {@code vec} (width {@code dim}) back into row {@code row} of {@code table}. */
    private static void setVec(float[] table, float[] vec, int row, int dim) {
        System.arraycopy(vec, 0, table, row * dim, dim);
    }

    /**
     * One logistic SGD step on the pair {@code (src, target)}.
     *
     * <p>Adds the source gradient into {@code srcGrad} using {@code target} before it is modified,
     * then updates {@code target} in place; {@code src} itself is left unchanged.
     *
     * @param src     source vector (read only)
     * @param target  target vector, updated in place
     * @param srcGrad accumulator for the source gradient
     * @param label   1 for a positive pair, anything else is treated as negative (0)
     * @param rate    learning rate
     * @param loss    {@code loss[0]} is decreased by the log-likelihood of this pair
     */
    protected static void update(float[] src, float[] target, float[] srcGrad, int label,
                                 double rate, double[] loss) {
        float sigmoid = ExpTableArray.sigmoid(floatDot(src, target));
        if (label == 1) {
            loss[0] -= ExpTableArray.log(sigmoid);
        } else {
            loss[0] -= ExpTableArray.log(1.0f - sigmoid);
        }
        float g = (float) ((label - sigmoid) * rate);
        floatAxpy(g, target, srcGrad);
        floatAxpy(g, src, target);
    }

    /** Dot product over the length of {@code a}; {@code b} must be at least as long. */
    public static float floatDot(float[] a, float[] b) {
        float dot = 0.0f;
        for (int k = 0; k < a.length; k++) {
            dot += a[k] * b[k];
        }
        return dot;
    }

    /** {@code y += alpha * x} over the length of {@code x}. */
    public static void floatAxpy(float alpha, float[] x, float[] y) {
        for (int k = 0; k < x.length; k++) {
            y[k] = y[k] + (alpha * x[k]);
        }
    }
}
