package com.alibaba.alink.operator.common.tree.parallelcart.loss;

import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.Arrays;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;

/**
 * Pairwise LambdaRank-style ranking loss.
 *
 * <p>For every query group and every pair of documents with different integer labels, this loss
 * accumulates a sigmoid-weighted "delta" built from the gain difference {@code 2^label - 1} of the
 * two labels and, once positions are known, the absolute difference of their position discounts
 * {@code 1 / log2(2 + position)}. With {@link LambdaType#NDCG} the delta is further scaled by a
 * per-group normalizer derived from the group's ideal DCG truncated at {@link #AT_T} positions.
 *
 * <p>Instances keep mutable scratch state ({@code labelCnt}, {@code sortedIndices}) and perform no
 * synchronization.
 */
public final class LambdaLoss implements RankingLossFunc {
    /** Size of the precomputed discount table; larger positions are computed on demand. */
    private static final int MAX_POS = 10000;
    /** Number of supported integer labels: valid labels are {@code 0 .. MAX_TARGET - 1}. */
    private static final int MAX_TARGET = 31;
    private static final double LOG2 = Math.log(2.0d);
    /** {@code GAIN_TABLE[label] = 2^label - 1}. */
    private static final double[] GAIN_TABLE = new double[MAX_TARGET];
    /** {@code DISCOUNT_TABLE[pos] = 1 / log2(2 + pos)} for positions below {@link #MAX_POS}. */
    private static final double[] DISCOUNT_TABLE = new double[MAX_POS];

    /** Truncation level used when computing each group's ideal DCG (default 3). */
    public static final ParamInfo<Integer> AT_T =
        ParamInfoFactory.createParamInfo("atT", Integer.class).setHasDefaultValue(3).build();
    /** Number of iterations between re-sorting each group's documents by score (default 10). */
    public static final ParamInfo<Integer> NUM_ITER_UPDATE_POS =
        ParamInfoFactory.createParamInfo("numIterUpdatePos", Integer.class).setHasDefaultValue(10).build();

    static {
        for (int label = 0; label < MAX_TARGET; label++) {
            GAIN_TABLE[label] = (1 << label) - 1.0d;
        }
        for (int pos = 0; pos < MAX_POS; pos++) {
            DISCOUNT_TABLE[pos] = computeDiscount(pos);
        }
    }

    private final int atT;
    private final int numIterUpdatePos;
    private final LambdaType lambdaType;
    /** Per-group document order (indices into the label/score arrays), refreshed periodically. */
    private final Integer[] sortedIndices;
    /** Per-group NDCG normalizer computed by {@link #initial}; {@code null} for DCG. */
    private double[] maxDcgInverse;
    /** Scratch label histogram reused for every group processed by {@link #initial}. */
    private final int[] labelCnt = new int[MAX_TARGET];

    /** Selects whether pair deltas are used as-is (DCG) or scaled by the per-group normalizer (NDCG). */
    public enum LambdaType {
        DCG,
        NDCG
    }

    /**
     * Creates the loss and, for NDCG, precomputes the per-group normalizers.
     *
     * @param params       source of {@link #AT_T} and {@link #NUM_ITER_UPDATE_POS}
     * @param lambdaType   DCG or NDCG weighting
     * @param queryOffsets group boundaries: group {@code g} covers indices
     *                     {@code [queryOffsets[g], queryOffsets[g + 1])}
     * @param labels       per-document labels (truncated to {@code int}); its length sizes the
     *                     internal document-order buffer
     * @param extra        forwarded to {@link #initial}, which does not read it
     * @throws IllegalArgumentException if {@code numIterUpdatePos <= 0}, or for NDCG if a label is
     *                                  outside {@code [0, MAX_TARGET - 1]}
     */
    public LambdaLoss(Params params, LambdaType lambdaType, int[] queryOffsets, double[] labels, double[] extra) {
        this.atT = params.get(AT_T);
        this.numIterUpdatePos = params.get(NUM_ITER_UPDATE_POS);
        if (this.numIterUpdatePos <= 0) {
            // gradients() divides by this value; fail fast instead of mid-training.
            throw new IllegalArgumentException(
                "numIterUpdatePos must be positive, got " + this.numIterUpdatePos);
        }
        this.lambdaType = lambdaType;
        this.sortedIndices = new Integer[labels.length];
        initial(queryOffsets, labels, extra);
    }

    /**
     * For NDCG, computes each group's normalizer: {@code idealDcg} is the DCG of the group's labels
     * in descending order truncated at {@code min(groupSize, atT)} positions, and the stored value
     * is {@code max(1, idealDcg)}, inverted when it exceeds {@link Criteria#INVALID_GAIN}.
     * Does nothing for DCG.
     *
     * @param queryOffsets group boundaries, as in the constructor
     * @param labels       per-document labels (truncated to {@code int})
     * @param extra        not read
     * @throws IllegalArgumentException if a label is outside {@code [0, MAX_TARGET - 1]}
     */
    public void initial(int[] queryOffsets, double[] labels, double[] extra) {
        if (lambdaType != LambdaType.NDCG) {
            return;
        }
        final int numGroups = queryOffsets.length - 1;
        maxDcgInverse = new double[numGroups];
        for (int g = 0; g < numGroups; g++) {
            final int start = queryOffsets[g];
            final int end = queryOffsets[g + 1];

            Arrays.fill(labelCnt, 0);
            for (int doc = start; doc < end; doc++) {
                labelCnt[checkLabel((int) labels[doc])]++;
            }

            // Walk the histogram from the highest label downwards. It holds (end - start) >= cutoff
            // entries, so it never runs out before the cutoff is reached.
            final int cutoff = Math.min(end - start, atT);
            int label = MAX_TARGET - 1;
            double idealDcg = 0.0d;
            for (int pos = 0; pos < cutoff; pos++) {
                while (label > 0 && labelCnt[label] == 0) {
                    label--;
                }
                idealDcg += GAIN_TABLE[label] * discount(pos);
                labelCnt[label]--;
            }

            double normalizer = Math.max(1.0d, idealDcg);
            if (normalizer > Criteria.INVALID_GAIN) {
                normalizer = 1.0d / normalizer;
            }
            maxDcgInverse[g] = normalizer;
        }
    }

    /**
     * Accumulates pairwise lambda statistics for one query group.
     *
     * <p>For {@code iter > 0} and every multiple of {@code numIterUpdatePos}, the group's documents
     * are re-sorted by descending score; on {@code iter == 0} they are reset to index order. For each
     * pair of documents with different labels, where {@code high}/{@code low} carries the larger/
     * smaller label and {@code posHigh}/{@code posLow} is its position in that order:
     * <pre>
     *   delta  = gain(high) - gain(low)                                      if iter / numIterUpdatePos == 0
     *   delta  = (gain(high) - gain(low)) * |discount(posHigh) - discount(posLow)|   otherwise
     *   delta *= maxDcgInverse[groupId]                                      for NDCG
     *   rho    = sigmoid(scores[low] - scores[high])
     *   lambdas[high] += rho * delta;   lambdas[low] -= rho * delta
     *   weights[high] += delta;         weights[low] += delta
     *   hessians[high] += rho * (1 - rho) * delta;   hessians[low] += same
     * </pre>
     * The three output arrays are first reset to {@link Criteria#INVALID_GAIN} over the group range.
     *
     * <p>The stored order for a group is only written when {@code iter == 0} or on a re-sort
     * iteration, so the first call for each group must be one of those.
     *
     * @param queryOffsets group boundaries, as in the constructor
     * @param groupId      index of the group to process
     * @param iter         iteration counter controlling re-sorting and position discounts
     * @param labels       per-document labels (truncated to {@code int})
     * @param scores       per-document current scores
     * @param weights      output: per-document sum of pair deltas
     * @param lambdas      output: per-document signed sum of {@code rho * delta}
     * @param hessians     output: per-document sum of {@code rho * (1 - rho) * delta}
     * @throws IllegalArgumentException if a compared label is outside {@code [0, MAX_TARGET - 1]}
     */
    @Override // com.alibaba.alink.operator.common.tree.parallelcart.loss.RankingLossFunc
    public void gradients(int[] queryOffsets, int groupId, int iter, double[] labels, double[] scores,
                          double[] weights, double[] lambdas, double[] hessians) {
        final int start = queryOffsets[groupId];
        final int end = queryOffsets[groupId + 1];
        final int size = end - start;
        // Until the first re-sort the documents are still in index order, so positions carry no
        // ranking information and the discount term is skipped.
        final boolean usePositions = iter / numIterUpdatePos != 0;

        if (iter > 0 && iter % numIterUpdatePos == 0) {
            resetOrder(start, end);
            Arrays.sort(sortedIndices, start, end, (a, b) -> Double.compare(scores[b], scores[a]));
        } else if (iter == 0) {
            resetOrder(start, end);
        }

        Arrays.fill(weights, start, end, Criteria.INVALID_GAIN);
        Arrays.fill(lambdas, start, end, Criteria.INVALID_GAIN);
        Arrays.fill(hessians, start, end, Criteria.INVALID_GAIN);

        for (int posA = 0; posA < size - 1; posA++) {
            final int docA = sortedIndices[start + posA];
            final int labelA = (int) labels[docA];
            for (int posB = posA + 1; posB < size; posB++) {
                final int docB = sortedIndices[start + posB];
                final int labelB = (int) labels[docB];
                if (labelA == labelB) {
                    continue;
                }

                final boolean aIsHigh = labelA > labelB;
                final int high = aIsHigh ? docA : docB;
                final int highLabel = aIsHigh ? labelA : labelB;
                final int highPos = aIsHigh ? posA : posB;
                final int low = aIsHigh ? docB : docA;
                final int lowLabel = aIsHigh ? labelB : labelA;
                final int lowPos = aIsHigh ? posB : posA;

                double delta = gain(highLabel) - gain(lowLabel);
                if (usePositions) {
                    delta *= Math.abs(discount(highPos) - discount(lowPos));
                }
                if (lambdaType == LambdaType.NDCG) {
                    delta *= maxDcgInverse[groupId];
                }

                final double rho = sigmoid(scores[low] - scores[high]);
                final double lambda = rho * delta;
                final double hessian = rho * (1.0d - rho) * delta;

                lambdas[high] += lambda;
                lambdas[low] -= lambda;
                weights[high] += delta;
                weights[low] += delta;
                hessians[high] += hessian;
                hessians[low] += hessian;
            }
        }
    }

    /** Logistic function {@code 1 / (1 + e^-x)}. */
    public static double sigmoid(double x) {
        return 1.0d / (1.0d + Math.exp(-x));
    }

    /** Resets the stored order of {@code [start, end)} to the identity permutation. */
    private void resetOrder(int start, int end) {
        for (int doc = start; doc < end; doc++) {
            sortedIndices[doc] = doc;
        }
    }

    /** Position discount {@code 1 / log2(2 + pos)}; single source of truth for the table and fallback. */
    private static double computeDiscount(int pos) {
        return (1.0d / Math.log(2.0d + pos)) * LOG2;
    }

    /** Table lookup for small positions; computed directly beyond {@link #MAX_POS} instead of overflowing. */
    private static double discount(int pos) {
        return pos < MAX_POS ? DISCOUNT_TABLE[pos] : computeDiscount(pos);
    }

    /** Gain {@code 2^label - 1} for a validated label. */
    private static double gain(int label) {
        return GAIN_TABLE[checkLabel(label)];
    }

    /** Returns {@code label} if it indexes the gain table, otherwise throws with a descriptive message. */
    private static int checkLabel(int label) {
        if (label < 0 || label >= MAX_TARGET) {
            throw new IllegalArgumentException(
                "LambdaLoss supports integer labels in [0, " + (MAX_TARGET - 1) + "], got " + label);
        }
        return label;
    }
}
