package com.alibaba.alink.operator.common.evaluation;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.operator.common.evaluation.ClassificationEvaluationUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/BinaryMetricsSummary.class */
/**
 * Mergeable summary used to compute binary-classification metrics.
 *
 * <p>Predictions are accumulated into {@code ClassificationEvaluationUtil.DETAIL_BIN_NUMBER} bins, where bin
 * {@code i} corresponds to threshold {@code i / DETAIL_BIN_NUMBER}. {@link #positiveBin} and {@link #negativeBin}
 * hold the per-bin counts of positive and negative samples respectively. Summaries from different partitions are
 * combined with {@link #merge(BinaryMetricsSummary)} and finally turned into {@link BinaryClassMetrics} with
 * {@link #toMetrics()}.
 */
public final class BinaryMetricsSummary implements BaseMetricsSummary<BinaryClassMetrics, BinaryMetricsSummary> {
    private static final long serialVersionUID = 4614108912380382179L;

    /** Minimum threshold gap between two consecutive points kept when sampling the curves. */
    private static final double PROBABILITY_INTERVAL = 0.001d;

    /** Tolerance used when comparing thresholds during sampling. */
    private static final double PROBABILITY_ERROR = 1.0E-5d;

    /** Decision threshold used by the constructor that does not take one explicitly. */
    private static final double DEFAULT_DECISION_THRESHOLD = 0.5d;

    /** Class labels; two summaries can only be merged when their labels are equal. */
    Object[] labels;

    /** Total number of samples; must equal the sum of all bin counts. */
    long total;

    /** Per-bin counts of positive samples. */
    long[] positiveBin;

    /** Per-bin counts of negative samples. */
    long[] negativeBin;

    /** Threshold whose closest sampled confusion matrix is used for the scalar metrics. */
    private double decisionThreshold;

    /** Accumulated log loss (summed, not averaged). */
    double logLoss;

    public BinaryMetricsSummary() {
    }

    /**
     * Creates a summary with the default decision threshold of 0.5.
     *
     * @param positiveBin per-bin counts of positive samples
     * @param negativeBin per-bin counts of negative samples
     * @param labels      class labels
     * @param logLoss     accumulated log loss
     * @param total       total number of samples
     */
    public BinaryMetricsSummary(long[] positiveBin, long[] negativeBin, Object[] labels, double logLoss, long total) {
        this(positiveBin, negativeBin, labels, DEFAULT_DECISION_THRESHOLD, logLoss, total);
    }

    /**
     * Creates a summary.
     *
     * @param positiveBin       per-bin counts of positive samples
     * @param negativeBin       per-bin counts of negative samples
     * @param labels            class labels
     * @param decisionThreshold threshold used to pick the confusion matrix for the scalar metrics
     * @param logLoss           accumulated log loss
     * @param total             total number of samples
     */
    public BinaryMetricsSummary(long[] positiveBin, long[] negativeBin, Object[] labels, double decisionThreshold,
                                double logLoss, long total) {
        this.positiveBin = positiveBin;
        this.negativeBin = negativeBin;
        this.labels = labels;
        this.decisionThreshold = decisionThreshold;
        this.logLoss = logLoss;
        this.total = total;
    }

    /**
     * Adds the bin counts, log loss and total of {@code other} into this summary (in place).
     *
     * @param other summary to merge; {@code null} is ignored
     * @return this summary
     */
    @Override
    public BinaryMetricsSummary merge(BinaryMetricsSummary other) {
        if (null == other) {
            return this;
        }
        AkPreconditions.checkState(Arrays.equals(this.labels, other.labels), "The labels are not the same!");
        // Without this check a shorter "other" throws ArrayIndexOutOfBounds and a longer one is silently truncated.
        AkPreconditions.checkState(
            this.positiveBin.length == other.positiveBin.length
                && this.negativeBin.length == other.negativeBin.length,
            "The bin sizes are not the same!");
        for (int i = 0; i < this.positiveBin.length; i++) {
            this.positiveBin[i] += other.positiveBin[i];
        }
        for (int i = 0; i < this.negativeBin.length; i++) {
            this.negativeBin[i] += other.negativeBin[i];
        }
        this.logLoss += other.logLoss;
        this.total += other.total;
        return this;
    }

    /**
     * Builds the final metrics.
     *
     * <p>Curve areas (AUC, PRC, KS, GINI) are computed on the full-resolution curves. Curve points, per-threshold
     * arrays and the scalar metrics are taken from the sampled curves. The scalar metrics use the sampled threshold
     * closest to {@code decisionThreshold}.
     */
    @Override
    public BinaryClassMetrics toMetrics() {
        String[] labelStrs = new String[this.labels.length];
        for (int i = 0; i < this.labels.length; i++) {
            labelStrs[i] = this.labels[i].toString();
        }
        Params params = new Params();
        Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> matrixThreCurve =
            extractMatrixThreCurve(this.positiveBin, this.negativeBin, this.total);
        setCurveAreaParams(params, matrixThreCurve.f2);
        Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> sampled = sample(PROBABILITY_INTERVAL, matrixThreCurve);
        setCurvePointsParams(params, sampled);
        setComputationsArrayParams(params, sampled.f1, sampled.f0);
        ClassificationEvaluationUtil.setLoglossParams(params, this.logLoss, this.total);
        int middleIndex = getMiddleThresholdIndex(sampled.f1, this.decisionThreshold);
        setMiddleThreParams(params, sampled.f0[middleIndex], labelStrs);
        return new BinaryClassMetrics(params);
    }

    /**
     * Sets precision, recall and F1 (for label index 0), plus the common classification params, from a single
     * confusion matrix.
     */
    public static void setMiddleThreParams(Params params, ConfusionMatrix confusionMatrix, String[] labels) {
        params.set(BinaryClassMetrics.PRECISION,
            ClassificationEvaluationUtil.Computations.PRECISION.computer.apply(confusionMatrix, 0));
        params.set(BinaryClassMetrics.RECALL,
            ClassificationEvaluationUtil.Computations.RECALL.computer.apply(confusionMatrix, 0));
        params.set(BinaryClassMetrics.F1,
            ClassificationEvaluationUtil.Computations.F1.computer.apply(confusionMatrix, 0));
        ClassificationEvaluationUtil.setClassificationCommonParams(params, confusionMatrix, labels);
    }

    /**
     * Stores the XY arrays of the curves.
     *
     * <p>Curve order is: 0 = ROC, 1 = precision-recall, 2 = lift chart, 3 = Lorenz curve.
     */
    public static void setCurvePointsParams(Params params,
                                            Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> tuple3) {
        EvaluationCurve[] curves = tuple3.f2;
        params.set(BinaryClassMetrics.ROC_CURVE, curves[0].getXYArray());
        params.set(BinaryClassMetrics.PRECISION_RECALL_CURVE, curves[1].getXYArray());
        params.set(BinaryClassMetrics.LIFT_CHART, curves[2].getXYArray());
        params.set(BinaryClassMetrics.LORENZ_CURVE, curves[3].getXYArray());
    }

    /**
     * Computes the area-based metrics.
     *
     * <ul>
     *   <li>AUC: area under the ROC curve.</li>
     *   <li>PRC: area under the PR curve.</li>
     *   <li>KS: from the ROC curve.</li>
     *   <li>GINI: {@code (LorenzArea - 0.5) / 0.5}.</li>
     * </ul>
     */
    public static void setCurveAreaParams(Params params, EvaluationCurve[] curves) {
        params.set(BinaryClassMetrics.AUC, curves[0].calcArea());
        params.set(BinaryClassMetrics.PRC, curves[1].calcArea());
        params.set(BinaryClassMetrics.KS, curves[0].calcKs());
        params.set(BinaryClassMetrics.GINI, (curves[3].calcArea() - 0.5d) / 0.5d);
    }

    /**
     * Stores the threshold array and, for every {@link ClassificationEvaluationUtil.Computations} entry, the
     * metric evaluated on each confusion matrix (label index 0).
     */
    public static void setComputationsArrayParams(Params params, double[] thresholds, ConfusionMatrix[] matrices) {
        params.set(BinaryClassMetrics.THRESHOLD_ARRAY, thresholds);
        ClassificationEvaluationUtil.Computations[] computations = ClassificationEvaluationUtil.Computations.values();
        double[][] values = new double[computations.length][matrices.length];
        for (int i = 0; i < matrices.length; i++) {
            for (ClassificationEvaluationUtil.Computations c : computations) {
                values[c.ordinal()][i] = c.computer.apply(matrices[i], 0).doubleValue();
            }
        }
        for (ClassificationEvaluationUtil.Computations c : computations) {
            params.set(c.arrayParamInfo, values[c.ordinal()]);
        }
    }

    /**
     * Turns per-bin counts into cumulative confusion matrices, thresholds and curves.
     *
     * <p>Only non-empty bins (plus the middle bin {@code DETAIL_BIN_NUMBER / 2}) are kept. They are walked from the
     * highest threshold downwards, accumulating counts.
     *
     * <p>Entry {@code k > 0} of every output array corresponds to predicting positive for all samples in the first
     * {@code k} kept bins. Its confusion matrix is
     * {@code {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}}}.
     *
     * <p>Entry 0 is the origin, with threshold 1.0 and nothing predicted positive.
     *
     * @param positiveBin per-bin positive counts
     * @param negativeBin per-bin negative counts
     * @param total       expected total count; must equal the sum of the kept bins
     * @return (confusion matrices, thresholds, curves [ROC, PR, lift, Lorenz])
     */
    static Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> extractMatrixThreCurve(long[] positiveBin,
                                                                                         long[] negativeBin,
                                                                                         long total) {
        int binNumber = ClassificationEvaluationUtil.DETAIL_BIN_NUMBER;
        ArrayList<Integer> effectiveIndices = new ArrayList<>();
        long totalTrue = 0;
        long totalFalse = 0;
        for (int i = 0; i < binNumber; i++) {
            // The middle bin is always kept so that a threshold of (about) 0.5 is always present.
            if (0 != positiveBin[i] || 0 != negativeBin[i] || i == binNumber / 2) {
                effectiveIndices.add(i);
                totalTrue += positiveBin[i];
                totalFalse += negativeBin[i];
            }
        }
        AkPreconditions.checkState(totalFalse + totalTrue == total,
            "The effective number in bins must be equal to total!");

        int size = effectiveIndices.size();
        int length = size + 1;
        double binWidth = 1.0d / binNumber;
        EvaluationCurvePoint[] rocPoints = new EvaluationCurvePoint[length];
        EvaluationCurvePoint[] prPoints = new EvaluationCurvePoint[length];
        EvaluationCurvePoint[] liftPoints = new EvaluationCurvePoint[length];
        EvaluationCurvePoint[] lorenzPoints = new EvaluationCurvePoint[length];
        ConfusionMatrix[] matrices = new ConfusionMatrix[length];
        double[] thresholds = new double[length];

        long curTrue = 0;
        long curFalse = 0;
        for (int i = 1; i < length; i++) {
            int index = effectiveIndices.get(size - i);
            curTrue += positiveBin[index];
            curFalse += negativeBin[index];
            thresholds[i] = index * binWidth;
            matrices[i] = new ConfusionMatrix(new long[][] {
                {curTrue, curFalse},
                {totalTrue - curTrue, totalFalse - curFalse}});
            double tpr = totalTrue == 0 ? 1.0d : (1.0d * curTrue) / totalTrue;
            double fpr = totalFalse == 0 ? 1.0d : (1.0d * curFalse) / totalFalse;
            // Precision denominator is the number of predicted positives (TP + FP).
            double precision = curTrue + curFalse == 0 ? 1.0d : (1.0d * curTrue) / (curTrue + curFalse);
            double positiveRate = (1.0d * (curTrue + curFalse)) / total;
            rocPoints[i] = new EvaluationCurvePoint(fpr, tpr, thresholds[i]);
            prPoints[i] = new EvaluationCurvePoint(tpr, precision, thresholds[i]);
            liftPoints[i] = new EvaluationCurvePoint(positiveRate, curTrue, thresholds[i]);
            lorenzPoints[i] = new EvaluationCurvePoint(positiveRate, tpr, thresholds[i]);
        }

        // Origin: threshold 1.0, no sample predicted positive.
        thresholds[0] = 1.0d;
        matrices[0] = new ConfusionMatrix(new long[][] {{0, 0}, {totalTrue, totalFalse}});
        rocPoints[0] = new EvaluationCurvePoint(0.0d, 0.0d, thresholds[0]);
        // The PR curve starts at recall 0 with the precision of the first real point.
        prPoints[0] = new EvaluationCurvePoint(0.0d, prPoints[1].getY(), thresholds[0]);
        liftPoints[0] = new EvaluationCurvePoint(0.0d, 0.0d, thresholds[0]);
        lorenzPoints[0] = new EvaluationCurvePoint(0.0d, 0.0d, thresholds[0]);

        return Tuple3.of(matrices, thresholds, new EvaluationCurve[] {
            new EvaluationCurve(rocPoints),
            new EvaluationCurve(prPoints),
            new EvaluationCurve(liftPoints),
            new EvaluationCurve(lorenzPoints)});
    }

    /** Returns the index of the threshold closest to 0.5. */
    static int getMiddleThresholdIndex(double[] thresholds) {
        return getMiddleThresholdIndex(thresholds, DEFAULT_DECISION_THRESHOLD);
    }

    /**
     * Returns the index of the threshold closest to {@code target}.
     *
     * <p>Ties resolve to the first such index; an empty array yields 0.
     */
    public static int getMiddleThresholdIndex(double[] thresholds, double target) {
        double minDistance = Double.MAX_VALUE;
        int bestIndex = 0;
        for (int i = 0; i < thresholds.length; i++) {
            double distance = Math.abs(thresholds[i] - target);
            if (distance < minDistance) {
                minDistance = distance;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Down-samples the output of {@link #extractMatrixThreCurve}.
     *
     * <p>A point is kept when its threshold differs from the last kept one by at least
     * {@code interval - PROBABILITY_ERROR}, or when its threshold is within {@code PROBABILITY_ERROR} of 0.5.
     *
     * <p>The curves always keep the origin (index 0). The returned thresholds and matrices omit it.
     *
     * @param interval minimum threshold gap between consecutive kept points
     * @param tuple3   (confusion matrices, thresholds, curves) as produced by {@link #extractMatrixThreCurve}
     * @return the sampled (confusion matrices, thresholds, curves)
     */
    static Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> sample(
        double interval, Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> tuple3) {
        ConfusionMatrix[] matrices = tuple3.f0;
        double[] thresholds = tuple3.f1;
        EvaluationCurve[] curves = tuple3.f2;

        ArrayList<Integer> kept = new ArrayList<>();
        kept.add(0);
        double lastKept = thresholds[0];
        for (int i = 0; i < thresholds.length; i++) {
            if (Math.abs(lastKept - thresholds[i]) >= interval - PROBABILITY_ERROR
                || Math.abs(thresholds[i] - 0.5d) < PROBABILITY_ERROR) {
                kept.add(i);
                lastKept = thresholds[i];
            }
        }

        int keptSize = kept.size();
        double[] sampledThresholds = new double[keptSize - 1];
        ConfusionMatrix[] sampledMatrices = new ConfusionMatrix[keptSize - 1];
        EvaluationCurvePoint[][] sampledPoints = new EvaluationCurvePoint[curves.length][keptSize];
        for (int k = 0; k < keptSize; k++) {
            int index = kept.get(k);
            if (k > 0) {
                sampledThresholds[k - 1] = thresholds[index];
                sampledMatrices[k - 1] = matrices[index];
            }
            for (int c = 0; c < curves.length; c++) {
                sampledPoints[c][k] = curves[c].getPoints()[index];
            }
        }

        EvaluationCurve[] sampledCurves = new EvaluationCurve[curves.length];
        for (int c = 0; c < curves.length; c++) {
            sampledCurves[c] = new EvaluationCurve(sampledPoints[c]);
        }
        return Tuple3.of(sampledMatrices, sampledThresholds, sampledCurves);
    }
}
