package com.alibaba.alink.operator.common.tree.parallelcart.loss;

import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.regression.GBRankParams;
import java.util.Arrays;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/tree/parallelcart/loss/GBRankLoss.class */
/**
 * Pairwise loss over grouped items: within each group, every pair of items whose labels differ
 * after truncation to {@code int} contributes a sigmoid-based term to the per-item accumulators.
 *
 * <p>Groups are described by an offsets array: group {@code g} covers the item indices
 * {@code [offsets[g], offsets[g + 1])}.
 */
public final class GBRankLoss implements RankingLossFunc {
    /** Multiplier applied to the label gap of a pair; read from {@code GBRankParams.TAU}. */
    double tau;
    /** Base for exponentiating labels when greater than 1; read from {@code GBRankParams.P}. */
    double p;

    /**
     * Reads {@code TAU} and {@code P} from the parameters and then runs
     * {@link #initial(int[], double[], double[])} on the supplied arrays.
     *
     * @param params parameter container providing {@code GBRankParams.TAU} and {@code GBRankParams.P}
     * @param iArr   group offsets, forwarded to {@code initial}
     * @param dArr   per-item labels, forwarded to {@code initial}
     * @param dArr2  per-item output array, overwritten by {@code initial}
     */
    public GBRankLoss(Params params, int[] iArr, double[] dArr, double[] dArr2) {
        this.tau = (Double) params.get(GBRankParams.TAU);
        this.p = (Double) params.get(GBRankParams.P);
        initial(iArr, dArr, dArr2);
    }

    /**
     * Fills {@code dArr2} with, for every item, the number of same-group items whose truncated
     * label differs from its own, counted on top of the {@code Criteria.INVALID_GAIN} start value.
     *
     * @param iArr  group offsets; group {@code g} spans {@code [iArr[g], iArr[g + 1])}
     * @param dArr  per-item labels; only their {@code int} truncation is compared
     * @param dArr2 output array; fully reset, then incremented by 1 per differing pair member
     */
    public void initial(int[] iArr, double[] dArr, double[] dArr2) {
        // NOTE(review): the counters start from Criteria.INVALID_GAIN, presumably 0.0 - confirm.
        Arrays.fill(dArr2, Criteria.INVALID_GAIN);

        final int groupCount = iArr.length - 1;
        for (int group = 0; group < groupCount; group++) {
            final int begin = iArr[group];
            final int size = iArr[group + 1] - begin;

            // Visit each unordered pair (a, b) with a < b inside the group exactly once.
            for (int a = 0; a < size - 1; a++) {
                for (int b = a + 1; b < size; b++) {
                    final int first = a + begin;
                    final int second = b + begin;
                    if (((int) dArr[first]) == ((int) dArr[second])) {
                        continue;
                    }
                    dArr2[first] += 1.0d;
                    dArr2[second] += 1.0d;
                }
            }
        }
    }

    /**
     * Recomputes the two accumulator slices of one group from all of its pairs with differing
     * truncated labels.
     *
     * <p>For such a pair, let {@code hi} be the item with the larger truncated label and
     * {@code lo} the other one. With
     * {@code s = LambdaLoss.sigmoid(gap * tau - dArr2[hi] + dArr2[lo])}, the method adds
     * {@code s} to {@code dArr4[hi]}, subtracts {@code s} from {@code dArr4[lo]}, and adds
     * {@code s * (1 - s)} to both {@code dArr5[hi]} and {@code dArr5[lo]}.
     *
     * @param iArr  group offsets; the processed group spans {@code [iArr[i], iArr[i + 1])}
     * @param i     index of the group to process
     * @param i2    not read by this implementation
     * @param dArr  per-item labels; truncated to {@code int} to order a pair, used untruncated
     *              for the gap
     * @param dArr2 per-item values entering the sigmoid argument (presumably the current model
     *              scores - confirm against the caller)
     * @param dArr3 not read by this implementation
     * @param dArr4 output; the group's slice is reset to {@code Criteria.INVALID_GAIN} and then
     *              accumulates {@code +s} / {@code -s}
     * @param dArr5 output; the group's slice is reset to {@code Criteria.INVALID_GAIN} and then
     *              accumulates {@code s * (1 - s)}
     */
    @Override
    public void gradients(int[] iArr, int i, int i2, double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4, double[] dArr5) {
        final int begin = iArr[i];
        final int end = iArr[i + 1];
        final int size = end - begin;

        // Only this group's slice of the outputs is cleared; other groups are left untouched.
        Arrays.fill(dArr4, begin, end, Criteria.INVALID_GAIN);
        Arrays.fill(dArr5, begin, end, Criteria.INVALID_GAIN);

        for (int a = 0; a < size - 1; a++) {
            for (int b = a + 1; b < size; b++) {
                final int first = a + begin;
                final int second = b + begin;
                final int firstLabel = (int) dArr[first];
                final int secondLabel = (int) dArr[second];
                if (firstLabel == secondLabel) {
                    // Pairs with equal truncated labels contribute nothing.
                    continue;
                }

                // Orient the pair so that "hi" carries the larger truncated label.
                final boolean firstIsHigher = firstLabel > secondLabel;
                final int hi = firstIsHigher ? first : second;
                final int lo = firstIsHigher ? second : first;

                final double scaledGap = labelGap(dArr[hi], dArr[lo]) * this.tau;
                final double s = LambdaLoss.sigmoid((scaledGap - dArr2[hi]) + dArr2[lo]);
                final double sTimesOneMinusS = s * (1.0d - s);

                dArr4[hi] += s;
                dArr4[lo] -= s;
                dArr5[hi] += sTimesOneMinusS;
                dArr5[lo] += sTimesOneMinusS;
            }
        }
    }

    /**
     * Gap between two labels: {@code p^higher - p^lower} when {@code p > 1}, otherwise the plain
     * difference {@code higher - lower}.
     */
    private double labelGap(double higher, double lower) {
        if (this.p > 1.0d) {
            return Math.pow(this.p, higher) - Math.pow(this.p, lower);
        }
        return higher - lower;
    }
}
