package com.alibaba.alink.operator.common.evaluation;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.MTableUtil;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.operator.batch.evaluation.EvalMultiLabelBatchOp;
import com.alibaba.alink.operator.common.evaluation.ClassificationMetricComputers;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.evaluation.EvalMultiClassParams;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.operators.SortPartitionOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/ClassificationEvaluationUtil.class */
public class ClassificationEvaluationUtil implements Serializable {
    // NOTE(review): STATISTICS_OUTPUT is not referenced in the visible part of this file.
    public static final String STATISTICS_OUTPUT = "Statistics";
    private static final long serialVersionUID = -2732226343798663348L;
    // NOTE(review): WINDOW and ALL are not referenced in the visible part of this file; presumably they
    // tag windowed vs. cumulative statistics elsewhere — confirm against callers.
    public static final Tuple2<String, Integer> WINDOW = Tuple2.of("window", 0);
    public static final Tuple2<String, Integer> ALL = Tuple2.of("all", 1);
    private static final Logger LOG = LoggerFactory.getLogger(ClassificationEvaluationUtil.class);
    // Broadcast-variable names used by the distributed binary-evaluation pipeline.
    // NOTE(review): these are public and non-final; treat them as constants. LABELS_BC_NAME itself is not
    // referenced in the visible part of this file (EvalMultiLabelBatchOp.LABELS is used instead).
    public static String LABELS_BC_NAME = "labels";
    public static String DECISION_THRESHOLD_BC_NAME = "score_boundary";
    public static String PARTITION_SUMMARIES_BC_NAME = "partition_summaries";
    // NOTE(review): not referenced in the visible part of this file.
    public static int DETAIL_BIN_NUMBER = 100000;
    // Indices into the long[] counter array built by reduceBinaryPartitionSummary:
    // CUR_* start as the positive/negative counts of the partitions ordered before the current one and are
    // incremented while that partition is scanned; TOTAL_* are the counts over all partitions.
    public static int TOTAL_TRUE = 2;
    public static int TOTAL_FALSE = 3;
    public static int CUR_TRUE = 0;
    public static int CUR_FALSE = 1;
    // Indices into the double[] "previous curve point" record used by updateAccurateBinaryMetricsSummary.
    private static int TPR = 0;
    private static int FPR = 1;
    private static int PRECISION = 2;
    private static int POSITIVE_RATE = 3;
    // Length of that record.
    public static int RECORD_LEN = 4;
    // Consecutive scores closer than this share one point of the metrics curve
    // (the last point's confusion matrix is overwritten instead of appending a new point).
    private static double PROBABILITY_ERROR = 0.001d;
    // Binary evaluation requires exactly this many distinct labels.
    public static int BINARY_LABEL_NUMBER = 2;

    /**
     * Summary of one parallel partition of binary-evaluation samples: the index of the task that produced
     * it, the largest score it saw, and how many positive and negative samples it counted.
     */
    public static class BinaryPartitionSummary implements Serializable {
        private static final long serialVersionUID = 1L;
        Integer taskId;
        double maxScore;
        long curPositive;
        long curNegative;

        public BinaryPartitionSummary(Integer taskId, double maxScore, long curPositive, long curNegative) {
            this.taskId = taskId;
            this.maxScore = maxScore;
            this.curPositive = curPositive;
            this.curNegative = curNegative;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/ClassificationEvaluationUtil$CalcBinaryMetricsSummary.class */
    static class CalcBinaryMetricsSummary extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BaseMetricsSummary> {
        private static final long serialVersionUID = 5680342197308160013L;
        private Object[] labels;
        private long[] countValues;
        private boolean firstBin;
        private double auc;
        private double decisionThreshold;
        private double largestThreshold;

        CalcBinaryMetricsSummary() {
        }

        public void open(Configuration configuration) throws Exception {
            List broadcastVariable = getRuntimeContext().getBroadcastVariable(EvalMultiLabelBatchOp.LABELS);
            AkPreconditions.checkState(broadcastVariable.size() > 0, (ExceptionWithErrorCode) new AkIllegalDataException("Please check the evaluation input! there is no effective row!"));
            this.labels = (Object[]) ((Tuple2) broadcastVariable.get(0)).f1;
            Tuple2<Boolean, long[]> reduceBinaryPartitionSummary = ClassificationEvaluationUtil.reduceBinaryPartitionSummary(getRuntimeContext().getBroadcastVariable(ClassificationEvaluationUtil.PARTITION_SUMMARIES_BC_NAME), getRuntimeContext().getIndexOfThisSubtask());
            this.firstBin = ((Boolean) reduceBinaryPartitionSummary.f0).booleanValue();
            this.countValues = (long[]) reduceBinaryPartitionSummary.f1;
            this.auc = ((Double) ((Tuple1) getRuntimeContext().getBroadcastVariable("auc").get(0)).f0).doubleValue();
            long j = this.countValues[ClassificationEvaluationUtil.TOTAL_TRUE];
            long j2 = this.countValues[ClassificationEvaluationUtil.TOTAL_FALSE];
            if (j == 0) {
                ClassificationEvaluationUtil.LOG.warn("There is no positive sample in data!");
            }
            if (j2 == 0) {
                ClassificationEvaluationUtil.LOG.warn("There is no negative sample in data!");
            }
            if (j <= 0 || j2 <= 0) {
                this.auc = Double.NaN;
            } else {
                this.auc = (this.auc - (((1.0d * j) * (j + 1)) / 2.0d)) / (j * j2);
            }
            if (getRuntimeContext().hasBroadcastVariable(ClassificationEvaluationUtil.DECISION_THRESHOLD_BC_NAME)) {
                this.decisionThreshold = ((Double) getRuntimeContext().getBroadcastVariable(ClassificationEvaluationUtil.DECISION_THRESHOLD_BC_NAME).get(0)).doubleValue();
                this.largestThreshold = Double.POSITIVE_INFINITY;
            } else {
                this.decisionThreshold = 0.5d;
                this.largestThreshold = 1.0d;
            }
        }

        public void mapPartition(Iterable<Tuple3<Double, Boolean, Double>> iterable, Collector<BaseMetricsSummary> collector) {
            AccurateBinaryMetricsSummary accurateBinaryMetricsSummary = new AccurateBinaryMetricsSummary(this.labels, this.decisionThreshold, Criteria.INVALID_GAIN, 0L, this.auc);
            double[] dArr = new double[ClassificationEvaluationUtil.RECORD_LEN];
            Iterator<Tuple3<Double, Boolean, Double>> it = iterable.iterator();
            while (it.hasNext()) {
                ClassificationEvaluationUtil.updateAccurateBinaryMetricsSummary(it.next(), accurateBinaryMetricsSummary, this.countValues, dArr, this.firstBin, this.decisionThreshold, this.largestThreshold);
            }
            collector.collect(accurateBinaryMetricsSummary);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/ClassificationEvaluationUtil$CalcBinaryPartitionSummary.class */
    static class CalcBinaryPartitionSummary extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryPartitionSummary> {
        private double decisionThreshold = 0.5d;
        private static final long serialVersionUID = 9012670438603117070L;

        CalcBinaryPartitionSummary() {
        }

        public void open(Configuration configuration) {
            this.decisionThreshold = getRuntimeContext().hasBroadcastVariable(ClassificationEvaluationUtil.DECISION_THRESHOLD_BC_NAME) ? ((Double) getRuntimeContext().getBroadcastVariable(ClassificationEvaluationUtil.DECISION_THRESHOLD_BC_NAME).get(0)).doubleValue() : 0.5d;
        }

        public void mapPartition(Iterable<Tuple3<Double, Boolean, Double>> iterable, Collector<BinaryPartitionSummary> collector) {
            BinaryPartitionSummary binaryPartitionSummary = new BinaryPartitionSummary(Integer.valueOf(getRuntimeContext().getIndexOfThisSubtask()), -1.7976931348623157E308d, 0L, 0L);
            iterable.forEach(tuple3 -> {
                ClassificationEvaluationUtil.updateBinaryPartitionSummary(binaryPartitionSummary, tuple3, this.decisionThreshold);
            });
            collector.collect(binaryPartitionSummary);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/ClassificationEvaluationUtil$CalcSampleOrders.class */
    static class CalcSampleOrders extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, Tuple3<Double, Long, Boolean>> {
        private static final long serialVersionUID = 3047511137846831576L;
        private long startIndex;
        private long total;
        private double decisionThreshold;

        CalcSampleOrders() {
        }

        public void open(Configuration configuration) throws Exception {
            Tuple2<Boolean, long[]> reduceBinaryPartitionSummary = ClassificationEvaluationUtil.reduceBinaryPartitionSummary(getRuntimeContext().getBroadcastVariable(ClassificationEvaluationUtil.PARTITION_SUMMARIES_BC_NAME), getRuntimeContext().getIndexOfThisSubtask());
            this.startIndex = ((long[]) reduceBinaryPartitionSummary.f1)[ClassificationEvaluationUtil.CUR_FALSE] + ((long[]) reduceBinaryPartitionSummary.f1)[ClassificationEvaluationUtil.CUR_TRUE] + 1;
            this.total = ((long[]) reduceBinaryPartitionSummary.f1)[ClassificationEvaluationUtil.TOTAL_TRUE] + ((long[]) reduceBinaryPartitionSummary.f1)[ClassificationEvaluationUtil.TOTAL_FALSE];
            this.decisionThreshold = getRuntimeContext().hasBroadcastVariable(ClassificationEvaluationUtil.DECISION_THRESHOLD_BC_NAME) ? ((Double) getRuntimeContext().getBroadcastVariable(ClassificationEvaluationUtil.DECISION_THRESHOLD_BC_NAME).get(0)).doubleValue() : 0.5d;
        }

        public void mapPartition(Iterable<Tuple3<Double, Boolean, Double>> iterable, Collector<Tuple3<Double, Long, Boolean>> collector) throws Exception {
            for (Tuple3<Double, Boolean, Double> tuple3 : iterable) {
                if (!ClassificationEvaluationUtil.isMiddlePoint(tuple3, this.decisionThreshold)) {
                    collector.collect(Tuple3.of(tuple3.f0, Long.valueOf((this.total - this.startIndex) + 1), tuple3.f1));
                    this.startIndex++;
                }
            }
        }
    }

    /**
     * Per-label classification metrics. Each constant pairs a metric computer with the parameter keys
     * for the metric's per-label array, its weighted average (weighted by actual-label proportion, see
     * frequencyAvgValue), its macro average and its micro average.
     *
     * <p>setClassificationCommonParams iterates over all constants and writes the weighted, macro and
     * micro values; arrayParamInfo is not used in the visible part of this file.
     *
     * <p>SPECITIVITY reuses TrueNegativeRate, while SENSITIVITY and RECALL reuse TruePositiveRate.
     */
    enum Computations {
        TRUE_NEGATIVE(new ClassificationMetricComputers.TrueNegativeRate(), BaseSimpleClassifierMetrics.TRUE_NEGATIVE_RATE_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_TRUE_NEGATIVE_RATE, BaseSimpleClassifierMetrics.MACRO_TRUE_NEGATIVE_RATE, BaseSimpleClassifierMetrics.MICRO_TRUE_NEGATIVE_RATE),
        TRUE_POSITIVE(new ClassificationMetricComputers.TruePositiveRate(), BaseSimpleClassifierMetrics.TRUE_POSITIVE_RATE_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_TRUE_POSITIVE_RATE, BaseSimpleClassifierMetrics.MACRO_TRUE_POSITIVE_RATE, BaseSimpleClassifierMetrics.MICRO_TRUE_POSITIVE_RATE),
        FALSE_NEGATIVE(new ClassificationMetricComputers.FalseNegativeRate(), BaseSimpleClassifierMetrics.FALSE_NEGATIVE_RATE_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_FALSE_NEGATIVE_RATE, BaseSimpleClassifierMetrics.MACRO_FALSE_NEGATIVE_RATE, BaseSimpleClassifierMetrics.MICRO_FALSE_NEGATIVE_RATE),
        FALSE_POSITIVE(new ClassificationMetricComputers.FalsePositiveRate(), BaseSimpleClassifierMetrics.FALSE_POSITIVE_RATE_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_FALSE_POSITIVE_RATE, BaseSimpleClassifierMetrics.MACRO_FALSE_POSITIVE_RATE, BaseSimpleClassifierMetrics.MICRO_FALSE_POSITIVE_RATE),
        PRECISION(new ClassificationMetricComputers.Precision(), BaseSimpleClassifierMetrics.PRECISION_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_PRECISION, BaseSimpleClassifierMetrics.MACRO_PRECISION, BaseSimpleClassifierMetrics.MICRO_PRECISION),
        // NOTE(review): "SPECITIVITY" is a misspelling of "specificity"; renaming the constant could break
        // code elsewhere that refers to it by name.
        SPECITIVITY(new ClassificationMetricComputers.TrueNegativeRate(), BaseSimpleClassifierMetrics.SPECIFICITY_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_SPECIFICITY, BaseSimpleClassifierMetrics.MACRO_SPECIFICITY, BaseSimpleClassifierMetrics.MICRO_SPECIFICITY),
        SENSITIVITY(new ClassificationMetricComputers.TruePositiveRate(), BaseSimpleClassifierMetrics.SENSITIVITY_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_SENSITIVITY, BaseSimpleClassifierMetrics.MACRO_SENSITIVITY, BaseSimpleClassifierMetrics.MICRO_SENSITIVITY),
        RECALL(new ClassificationMetricComputers.TruePositiveRate(), BaseSimpleClassifierMetrics.RECALL_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_RECALL, BaseSimpleClassifierMetrics.MACRO_RECALL, BaseSimpleClassifierMetrics.MICRO_RECALL),
        F1(new ClassificationMetricComputers.F1(), BaseSimpleClassifierMetrics.F1_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_F1, BaseSimpleClassifierMetrics.MACRO_F1, BaseSimpleClassifierMetrics.MICRO_F1),
        ACCURACY(new ClassificationMetricComputers.Accuracy(), BaseSimpleClassifierMetrics.ACCURACY_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_ACCURACY, BaseSimpleClassifierMetrics.MACRO_ACCURACY, BaseSimpleClassifierMetrics.MICRO_ACCURACY),
        KAPPA(new ClassificationMetricComputers.Kappa(), BaseSimpleClassifierMetrics.KAPPA_ARRAY, BaseSimpleClassifierMetrics.WEIGHTED_KAPPA, BaseSimpleClassifierMetrics.MACRO_KAPPA, BaseSimpleClassifierMetrics.MICRO_KAPPA);

        // Computes the metric for one label index (or, with a null index, the micro value).
        ClassificationMetricComputers.BaseClassificationMetricComputer computer;
        // Key for the per-label values of this metric.
        ParamInfo<double[]> arrayParamInfo;
        // Key for the label-proportion-weighted average.
        ParamInfo<Double> weightedParamInfo;
        // Key for the unweighted mean over labels.
        ParamInfo<Double> macroParamInfo;
        // Key for the value computed with a null label index.
        ParamInfo<Double> microParamInfo;

        // NOTE(review): raw ParamInfo parameters are a decompilation artifact; they are assigned unchecked
        // to the typed fields above.
        Computations(ClassificationMetricComputers.BaseClassificationMetricComputer baseClassificationMetricComputer, ParamInfo paramInfo, ParamInfo paramInfo2, ParamInfo paramInfo3, ParamInfo paramInfo4) {
            this.computer = baseClassificationMetricComputer;
            this.arrayParamInfo = paramInfo;
            this.weightedParamInfo = paramInfo2;
            this.macroParamInfo = paramInfo3;
            this.microParamInfo = paramInfo4;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/ClassificationEvaluationUtil$SampleStatisticsMapPartitionFunction.class */
    public static class SampleStatisticsMapPartitionFunction extends RichMapPartitionFunction<Row, Tuple3<Double, Boolean, Double>> {
        private static final long serialVersionUID = 5680342197308160013L;
        private Tuple2<Map<Object, Integer>, Object[]> map;
        private final TypeInformation<?> labelType;
        private final LabelProbMapExtractor extractor;
        private double decisionThreshold = 0.5d;

        public SampleStatisticsMapPartitionFunction(TypeInformation<?> typeInformation, LabelProbMapExtractor labelProbMapExtractor) {
            this.labelType = typeInformation;
            this.extractor = labelProbMapExtractor;
        }

        public void open(Configuration configuration) throws Exception {
            List broadcastVariable = getRuntimeContext().getBroadcastVariable(EvalMultiLabelBatchOp.LABELS);
            AkPreconditions.checkState(broadcastVariable.size() > 0, (ExceptionWithErrorCode) new AkIllegalDataException("Please check the evaluation input! there is no effective row!"));
            this.map = (Tuple2) broadcastVariable.get(0);
            this.decisionThreshold = getRuntimeContext().hasBroadcastVariable(ClassificationEvaluationUtil.DECISION_THRESHOLD_BC_NAME) ? ((Double) getRuntimeContext().getBroadcastVariable(ClassificationEvaluationUtil.DECISION_THRESHOLD_BC_NAME).get(0)).doubleValue() : 0.5d;
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Tuple3<Double, Boolean, Double>> collector) {
            Iterator<Row> it = iterable.iterator();
            while (it.hasNext()) {
                Tuple3<Double, Boolean, Double> binaryDetailStatistics = ClassificationEvaluationUtil.getBinaryDetailStatistics(it.next(), (Object[]) this.map.f1, this.labelType, this.extractor);
                if (null != binaryDetailStatistics) {
                    collector.collect(binaryDetailStatistics);
                }
            }
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                collector.collect(Tuple3.of(Double.valueOf(this.decisionThreshold), true, Double.valueOf(Double.NaN)));
            }
        }
    }

    /** Which column the evaluation reads predictions from (chosen by {@code judgeEvaluationType}). */
    public enum Type {
        /** The prediction-detail column is given and used. */
        PRED_DETAIL,
        /** Only the prediction column is given. */
        PRED_RESULT
    }

    /**
     * Converts one row into a binary-evaluation sample using a {@link DefaultLabelProbMapExtractor};
     * see the four-argument overload for the result format.
     */
    public static Tuple3<Double, Boolean, Double> getBinaryDetailStatistics(Row row, Object[] labels, TypeInformation<?> labelType) {
        LabelProbMapExtractor defaultExtractor = new DefaultLabelProbMapExtractor();
        return getBinaryDetailStatistics(row, labels, labelType, defaultExtractor);
    }

    /**
     * Converts one row into a binary-evaluation sample.
     *
     * @param row                   row whose field 0 is the actual label; the whole row is also handed to
     *                              {@code EvaluationUtil.extractLabelProbMap} to obtain per-label probabilities
     * @param labels                the two labels; {@code labels[0]} is the one whose probability is the score
     * @param labelType             type of the label column
     * @param labelProbMapExtractor extracts the label-to-probability map from the row
     * @return {@code (probability of labels[0], actual == labels[0], logLoss)}; {@code null} when
     *         the row has a null field or its actual label is neither of {@code labels}
     */
    public static Tuple3<Double, Boolean, Double> getBinaryDetailStatistics(Row row, Object[] labels, TypeInformation<?> labelType, LabelProbMapExtractor labelProbMapExtractor) {
        AkPreconditions.checkArgument(labels.length == BINARY_LABEL_NUMBER, "Label length is not 2, Only support binary evaluation!");
        if (!EvaluationUtil.checkRowFieldNotNull(row)) {
            return null;
        }
        TreeMap<Object, Double> labelProbMap = EvaluationUtil.extractLabelProbMap(row, labelType, labelProbMapExtractor);
        Object actualLabel = row.getField(0);
        AkPreconditions.checkState(labelProbMap.size() == BINARY_LABEL_NUMBER, "The number of labels must be equal to 2!");
        double logLoss = EvaluationUtil.extractLogloss(labelProbMap, actualLabel);
        Double positiveProb = labelProbMap.get(labels[0]);
        // A detail map keyed by labels other than labels[0] used to fail here with a bare NullPointerException.
        // The message is a constant so that no string is built per row.
        AkPreconditions.checkState(positiveProb != null, "The prediction detail does not contain a probability for the positive label!");
        if (actualLabel.equals(labels[0])) {
            return Tuple3.of(positiveProb, Boolean.TRUE, Double.valueOf(logLoss));
        }
        if (actualLabel.equals(labels[1])) {
            return Tuple3.of(positiveProb, Boolean.FALSE, Double.valueOf(logLoss));
        }
        return null;
    }

    /**
     * Builds the distributed pipeline that turns binary samples into per-partition metric summaries.
     *
     * <p>As wired below:
     * <ol>
     *   <li>samples are range-partitioned on field 0 (the score) and each partition is sorted by score descending;</li>
     *   <li>{@link CalcBinaryPartitionSummary} yields one summary per partition, broadcast as
     *       {@code PARTITION_SUMMARIES_BC_NAME};</li>
     *   <li>{@link CalcSampleOrders} ranks every sample; ranks are averaged per distinct score, multiplied by
     *       the number of positives with that score, and summed into one value broadcast as {@code "auc"};</li>
     *   <li>{@link CalcBinaryMetricsSummary} yields one {@link BaseMetricsSummary} per partition.</li>
     * </ol>
     *
     * @param dataSet  label information {@code (label -> index, label array)}, broadcast as {@code EvalMultiLabelBatchOp.LABELS}
     * @param dataSet2 samples {@code (score, isPositive, logLoss)}
     * @param dataSet3 decision threshold, broadcast as {@code DECISION_THRESHOLD_BC_NAME}
     * @return one metrics summary per partition
     */
    // NOTE(review): the raw SortPartitionOperator / SingleInputUdfOperator types are decompilation artifacts.
    public static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Tuple2<Map<Object, Integer>, Object[]>> dataSet, DataSet<Tuple3<Double, Boolean, Double>> dataSet2, DataSet<Double> dataSet3) {
        // Globally order samples by score: range partitions, each sorted descending.
        SortPartitionOperator sortPartition = dataSet2.partitionByRange(new int[]{0}).sortPartition(0, Order.DESCENDING);
        // One BinaryPartitionSummary per partition; shared by the rank and metrics steps below.
        SingleInputUdfOperator withBroadcastSet = sortPartition.mapPartition(new CalcBinaryPartitionSummary()).withBroadcastSet(dataSet3, DECISION_THRESHOLD_BC_NAME);
        // Metrics step; its "auc" input is the rank-sum sub-pipeline built inline (CalcSampleOrders -> groupBy score -> reduce -> sum).
        return sortPartition.mapPartition(new CalcBinaryMetricsSummary()).withBroadcastSet(withBroadcastSet, PARTITION_SUMMARIES_BC_NAME).withBroadcastSet(dataSet, EvalMultiLabelBatchOp.LABELS).withBroadcastSet(sortPartition.mapPartition(new CalcSampleOrders()).withBroadcastSet(withBroadcastSet, PARTITION_SUMMARIES_BC_NAME).withBroadcastSet(dataSet3, DECISION_THRESHOLD_BC_NAME).groupBy(new int[]{0}).reduceGroup(new GroupReduceFunction<Tuple3<Double, Long, Boolean>, Tuple1<Double>>() { // from class: com.alibaba.alink.operator.common.evaluation.ClassificationEvaluationUtil.1
            private static final long serialVersionUID = -7442946470184046220L;

            // For one distinct score: j = sum of ranks, j2 = number of samples, j3 = number of positives.
            // Emits averageRank * positives, i.e. the tie-averaged contribution of these positives to the rank sum.
            public void reduce(Iterable<Tuple3<Double, Long, Boolean>> iterable, Collector<Tuple1<Double>> collector) {
                long j = 0;
                long j2 = 0;
                long j3 = 0;
                for (Tuple3<Double, Long, Boolean> tuple3 : iterable) {
                    j += ((Long) tuple3.f1).longValue();
                    j2++;
                    if (((Boolean) tuple3.f2).booleanValue()) {
                        j3++;
                    }
                }
                collector.collect(Tuple1.of(Double.valueOf(((1.0d * j) / j2) * j3)));
            }
            // sum(0) adds the per-score contributions into the single value broadcast as "auc".
        }).sum(0), "auc").withBroadcastSet(dataSet3, DECISION_THRESHOLD_BC_NAME);
    }

    /**
     * Decides which prediction column drives the evaluation.
     *
     * @param params evaluation parameters
     * @return {@link Type#PRED_DETAIL} when the prediction-detail column is set (it takes precedence),
     *         otherwise {@link Type#PRED_RESULT} when the prediction column is set
     * @throws IllegalArgumentException when neither column is set
     */
    public static Type judgeEvaluationType(Params params) {
        if (params.contains(EvalMultiClassParams.PREDICTION_DETAIL_COL)) {
            return Type.PRED_DETAIL;
        }
        if (params.contains(EvalMultiClassParams.PREDICTION_COL)) {
            return Type.PRED_RESULT;
        }
        throw new IllegalArgumentException("Error Input, must give either predictionCol or predictionDetailCol!");
    }

    /**
     * Builds the label-to-index map and the ordered label array used by classification evaluation.
     *
     * <p>Labels are sorted in descending natural order. In binary mode with a positive value given,
     * the label matching it (per {@code EvaluationUtil.labelCompare}) is moved to index 0.
     *
     * @param labels        distinct labels
     * @param binary        whether this is binary evaluation; exactly two labels are then required
     * @param positiveValue the positive label as a string; may be {@code null}
     * @param labelType     label type, used when comparing labels with {@code positiveValue}
     * @param unused        NOTE(review): not read by this method
     * @return {@code (label -> index, labels in index order)}
     * @throws IllegalArgumentException in binary mode when neither label matches {@code positiveValue}
     */
    public static Tuple2<Map<Object, Integer>, Object[]> buildLabelIndexLabelArray(Set<Object> labels, boolean binary, String positiveValue, TypeInformation<?> labelType, boolean unused) {
        Object[] ordered = labels.toArray();
        Arrays.sort(ordered, Collections.reverseOrder());
        AkPreconditions.checkArgument(!binary || ordered.length == BINARY_LABEL_NUMBER, "The number of labels must be equal to 2!");
        Map<Object, Integer> labelToIndex = new HashMap<>(ordered.length);

        boolean placePositiveFirst = binary && positiveValue != null;
        if (!placePositiveFirst) {
            int index = 0;
            for (Object label : ordered) {
                labelToIndex.put(label, Integer.valueOf(index));
                index++;
            }
            return Tuple2.of(labelToIndex, ordered);
        }

        if (EvaluationUtil.labelCompare(ordered[1], positiveValue, labelType)) {
            Object positive = ordered[1];
            ordered[1] = ordered[0];
            ordered[0] = positive;
        } else if (!EvaluationUtil.labelCompare(ordered[0], positiveValue, labelType)) {
            throw new IllegalArgumentException("Not contain positiveValue");
        }
        labelToIndex.put(ordered[0], 0);
        labelToIndex.put(ordered[1], 1);
        return Tuple2.of(labelToIndex, ordered);
    }

    /**
     * Averages a per-label metric, weighting each label by its proportion among the actual labels.
     */
    private static double frequencyAvgValue(ClassificationMetricComputers.BaseClassificationMetricComputer computer, ConfusionMatrix matrix) {
        double[] proportions = matrix.getActualLabelProportion();
        double weightedSum = 0.0d;
        int label = 0;
        while (label < matrix.labelCnt) {
            weightedSum += computer.apply(matrix, Integer.valueOf(label)).doubleValue() * proportions[label];
            label++;
        }
        return weightedSum;
    }

    /**
     * Unweighted mean of a per-label metric over all labels.
     */
    private static double macroAvgValue(ClassificationMetricComputers.BaseClassificationMetricComputer computer, ConfusionMatrix matrix) {
        double total = 0.0d;
        int label = 0;
        while (label < matrix.labelCnt) {
            total += computer.apply(matrix, Integer.valueOf(label)).doubleValue();
            label++;
        }
        return total / matrix.labelCnt;
    }

    /**
     * Micro value of a metric: the computer is invoked with a {@code null} label index.
     * NOTE(review): presumably the computer then pools all labels — confirm in ClassificationMetricComputers.
     */
    private static double microAvgValue(ClassificationMetricComputers.BaseClassificationMetricComputer computer, ConfusionMatrix matrix) {
        final Integer allLabels = null;
        return computer.apply(matrix, allLabels).doubleValue();
    }

    /**
     * Returns a metric for every label, followed by its weighted, macro and micro averages.
     *
     * @return array of length {@code labelCnt + 3}: per-label values, then weighted, macro, micro
     */
    public static double[] getAllValues(ClassificationMetricComputers.BaseClassificationMetricComputer computer, ConfusionMatrix matrix) {
        int labelCount = matrix.labelCnt;
        double[] values = new double[labelCount + 3];
        for (int label = 0; label < labelCount; label++) {
            values[label] = computer.apply(matrix, Integer.valueOf(label)).doubleValue();
        }
        values[labelCount] = frequencyAvgValue(computer, matrix);
        values[labelCount + 1] = macroAvgValue(computer, matrix);
        values[labelCount + 2] = microAvgValue(computer, matrix);
        return values;
    }

    /**
     * Stores the mean log-loss {@code logLoss / total} under {@link BaseSimpleClassifierMetrics#LOG_LOSS}.
     *
     * <p>Nothing is stored when {@code logLoss} is negative. NOTE(review): presumably a negative value
     * means "log-loss not available"; confirm against callers. When {@code total} is 0 the stored value
     * is NaN or infinite.
     *
     * @param params  metrics parameters to update
     * @param logLoss accumulated log-loss over all samples
     * @param total   number of samples
     */
    public static void setLoglossParams(Params params, double logLoss, long total) {
        // The decompiled form compared against Criteria.INVALID_GAIN, an unrelated decision-tree constant;
        // compare with 0.0 directly (NOTE(review): assumed to be that constant's value — confirm).
        if (logLoss >= 0.0d) {
            // Params.set(ParamInfo<V>, V) infers V = Double; the decompiled casts were inconvertible and did not compile.
            params.set(BaseSimpleClassifierMetrics.LOG_LOSS, Double.valueOf(logLoss / total));
        }
    }

    /**
     * Writes the metrics derived from a confusion matrix into {@code params}: the label array, actual
     * label frequency and proportion, the raw matrix, the total sample count, the weighted/macro/micro
     * averages of every {@link Computations} metric, and the overall accuracy and kappa.
     *
     * @param params          metrics parameters to update
     * @param confusionMatrix source confusion matrix
     * @param labels          label names, stored as the label array
     */
    public static void setClassificationCommonParams(Params params, ConfusionMatrix confusionMatrix, String[] labels) {
        // Params.set(ParamInfo<V>, V) infers V from the key; the decompiled casts such as
        // (ParamInfo<ParamInfo<Double>>) and (ParamInfo<Double>) Double.valueOf(..) were inconvertible and did not compile.
        params.set(BaseSimpleClassifierMetrics.LABEL_ARRAY, labels);
        params.set(BaseSimpleClassifierMetrics.ACTUAL_LABEL_FREQUENCY, confusionMatrix.getActualLabelFrequency());
        params.set(BaseSimpleClassifierMetrics.ACTUAL_LABEL_PROPORTION, confusionMatrix.getActualLabelProportion());
        params.set(BaseSimpleClassifierMetrics.CONFUSION_MATRIX, confusionMatrix.longMatrix.getMatrix());
        params.set(BaseSimpleClassifierMetrics.TOTAL_SAMPLES, Long.valueOf(confusionMatrix.total));
        for (Computations computation : Computations.values()) {
            params.set(computation.weightedParamInfo, Double.valueOf(frequencyAvgValue(computation.computer, confusionMatrix)));
            params.set(computation.macroParamInfo, Double.valueOf(macroAvgValue(computation.computer, confusionMatrix)));
            params.set(computation.microParamInfo, Double.valueOf(microAvgValue(computation.computer, confusionMatrix)));
        }
        params.set(BaseSimpleClassifierMetrics.ACCURACY, Double.valueOf(confusionMatrix.getTotalAccuracy()));
        params.set(BaseSimpleClassifierMetrics.KAPPA, Double.valueOf(confusionMatrix.getTotalKappa()));
    }

    /**
     * Combines all partition summaries into the counters needed by one partition.
     *
     * <p>Summaries are visited in descending {@code maxScore} order. For the summary whose {@code taskId}
     * equals {@code i}, the result holds:
     * <ul>
     *   <li>{@code f0}: {@code true} when no positive or negative sample lies in the summaries visited before it;</li>
     *   <li>{@code f1[0]}, {@code f1[1]}: positives and negatives in those earlier summaries;</li>
     *   <li>{@code f1[2]}, {@code f1[3]}: positives and negatives over all summaries.</li>
     * </ul>
     * If no summary matches {@code i}, {@code f0} is {@code true} and {@code f1[0]}, {@code f1[1]} are 0.
     *
     * @param list all partition summaries (not modified)
     * @param i    task index of the partition of interest
     */
    public static Tuple2<Boolean, long[]> reduceBinaryPartitionSummary(List<BinaryPartitionSummary> list, int i) {
        List<BinaryPartitionSummary> ordered = new ArrayList<>(list);
        // Sort on the negated score rather than with reversed(): this keeps NaN scores last.
        ordered.sort(Comparator.comparingDouble((BinaryPartitionSummary s) -> -s.maxScore));

        long positivesBefore = 0L;
        long negativesBefore = 0L;
        long positivesSoFar = 0L;
        long negativesSoFar = 0L;
        boolean noSampleBefore = true;
        for (BinaryPartitionSummary summary : ordered) {
            boolean isTarget = summary.taskId.intValue() == i;
            if (isTarget) {
                noSampleBefore = positivesSoFar + negativesSoFar == 0;
                positivesBefore = positivesSoFar;
                negativesBefore = negativesSoFar;
            }
            positivesSoFar += summary.curPositive;
            negativesSoFar += summary.curNegative;
        }
        long[] counters = {positivesBefore, negativesBefore, positivesSoFar, negativesSoFar};
        return Tuple2.of(Boolean.valueOf(noSampleBefore), counters);
    }

    /**
     * Tells whether a sample is the decision-threshold marker: its score equals {@code d} (per
     * {@link Double#compare}), its label flag is {@code true} and its log-loss is NaN.
     */
    public static boolean isMiddlePoint(Tuple3<Double, Boolean, Double> tuple3, double d) {
        if (Double.compare(tuple3.f0.doubleValue(), d) != 0) {
            return false;
        }
        if (!tuple3.f1.booleanValue()) {
            return false;
        }
        return Double.isNaN(tuple3.f2.doubleValue());
    }

    /**
     * Adds one sample to a partition summary.
     *
     * <p>The sample is counted as positive or negative unless it is the decision-threshold marker.
     * The marker's score still takes part in the {@code maxScore} update.
     */
    public static void updateBinaryPartitionSummary(BinaryPartitionSummary binaryPartitionSummary, Tuple3<Double, Boolean, Double> tuple3, double d) {
        boolean counted = !isMiddlePoint(tuple3, d);
        if (counted && tuple3.f1.booleanValue()) {
            binaryPartitionSummary.curPositive++;
        } else if (counted) {
            binaryPartitionSummary.curNegative++;
        }
        double score = tuple3.f0.doubleValue();
        if (Double.compare(binaryPartitionSummary.maxScore, score) < 0) {
            binaryPartitionSummary.maxScore = score;
        }
    }

    /**
     * Feeds one sample (samples arrive in descending score order) into a partition's metrics summary.
     *
     * <p>Counts the sample (unless it is the decision-threshold marker, see {@link #isMiddlePoint}) and
     * updates the trapezoidal gini and PRC accumulators and the KS maximum. It then appends a
     * {@code (threshold, confusion matrix)} point, or overwrites the last point's matrix. In each
     * matrix, row 0 holds the samples scanned so far (actual true, actual false) and row 1 the rest.
     *
     * @param sample            {@code (score, isPositive, logLoss)}
     * @param summary           summary being built; total, logLoss, gini, prc, ks and metricsInfoList are updated
     * @param counts            counters indexed by CUR_TRUE / CUR_FALSE / TOTAL_TRUE / TOTAL_FALSE; CUR_* are
     *                          incremented in place
     * @param previous          previous curve point indexed by TPR / FPR / PRECISION / POSITIVE_RATE; updated in place
     * @param firstBin          whether no samples precede this partition; the curve's starting point is then added
     * @param decisionThreshold decision threshold (identifies the marker sample)
     * @param largestThreshold  threshold attached to the curve's starting point
     */
    public static void updateAccurateBinaryMetricsSummary(Tuple3<Double, Boolean, Double> sample, AccurateBinaryMetricsSummary summary, long[] counts, double[] previous, boolean firstBin, double decisionThreshold, double largestThreshold) {
        // Before the first counted sample, seed the previous point from the counts of earlier partitions.
        if (summary.total == 0) {
            previous[TPR] = counts[TOTAL_TRUE] == 0 ? 1.0d : (1.0d * counts[CUR_TRUE]) / counts[TOTAL_TRUE];
            previous[FPR] = counts[TOTAL_FALSE] == 0 ? 1.0d : (1.0d * counts[CUR_FALSE]) / counts[TOTAL_FALSE];
            previous[PRECISION] = counts[CUR_TRUE] + counts[CUR_FALSE] == 0 ? 1.0d : (1.0d * counts[CUR_TRUE]) / (counts[CUR_TRUE] + counts[CUR_FALSE]);
            previous[POSITIVE_RATE] = (1.0d * (counts[CUR_TRUE] + counts[CUR_FALSE])) / (counts[TOTAL_TRUE] + counts[TOTAL_FALSE]);
        }
        boolean marker = isMiddlePoint(sample, decisionThreshold);
        if (!marker) {
            summary.total++;
            summary.logLoss += sample.f2.doubleValue();
            if (sample.f1.booleanValue()) {
                counts[CUR_TRUE]++;
            } else {
                counts[CUR_FALSE]++;
            }
        }

        double score = sample.f0.doubleValue();
        double tpr = counts[TOTAL_TRUE] == 0 ? 1.0d : (1.0d * counts[CUR_TRUE]) / counts[TOTAL_TRUE];
        double fpr = counts[TOTAL_FALSE] == 0 ? 1.0d : (1.0d * counts[CUR_FALSE]) / counts[TOTAL_FALSE];
        double precision = counts[CUR_TRUE] + counts[CUR_FALSE] == 0 ? 1.0d : (1.0d * counts[CUR_TRUE]) / (counts[CUR_TRUE] + counts[CUR_FALSE]);
        double positiveRate = (1.0d * (counts[CUR_TRUE] + counts[CUR_FALSE])) / (counts[TOTAL_TRUE] + counts[TOTAL_FALSE]);
        List<Tuple2<Double, ConfusionMatrix>> points = summary.metricsInfoList;

        // The partition holding the globally first sample also records the curve's starting point,
        // where every sample is in row 1 of the matrix.
        if (summary.total == 1 && firstBin) {
            previous[PRECISION] = precision;
            // A real long[][] literal: the decompiled "(long[][]) new long[]{new long[]{..}}" did not compile.
            long[][] start = {{0L, 0L}, {counts[TOTAL_TRUE], counts[TOTAL_FALSE]}};
            points.add(Tuple2.of(Double.valueOf(largestThreshold), new ConfusionMatrix(start)));
        }

        // Trapezoidal accumulation between the previous point and this one.
        summary.gini += ((positiveRate - previous[POSITIVE_RATE]) * (tpr + previous[TPR])) / 2.0d;
        summary.prc += ((tpr - previous[TPR]) * (precision + previous[PRECISION])) / 2.0d;
        summary.ks = Math.max(Math.abs(fpr - tpr), summary.ks);
        previous[TPR] = tpr;
        previous[FPR] = fpr;
        previous[PRECISION] = precision;
        previous[POSITIVE_RATE] = positiveRate;

        long[][] current = {
            {counts[CUR_TRUE], counts[CUR_FALSE]},
            {counts[TOTAL_TRUE] - counts[CUR_TRUE], counts[TOTAL_FALSE] - counts[CUR_FALSE]}
        };
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(current);
        // Start a new point when there is none yet, when the marker arrives and the last point is not exactly
        // at the decision threshold, or when the score moved by at least PROBABILITY_ERROR. Otherwise
        // overwrite the last point's matrix.
        boolean newPoint;
        if (points.isEmpty()) {
            newPoint = true;
        } else {
            double lastThreshold = points.get(points.size() - 1).f0.doubleValue();
            newPoint = (marker && lastThreshold != decisionThreshold)
                || Math.abs(score - lastThreshold) >= PROBABILITY_ERROR;
        }
        if (newPoint) {
            points.add(Tuple2.of(Double.valueOf(score), confusionMatrix));
        } else {
            points.get(points.size() - 1).f1 = confusionMatrix;
        }
    }

    /**
     * Same as the five-argument overload with decision threshold 0.5 and a {@link DefaultLabelProbMapExtractor}.
     */
    public static List<Tuple3<Double, Boolean, Double>> calcSampleStatistics(List<Row> list, Tuple2<Map<Object, Integer>, Object[]> tuple2, TypeInformation<?> typeInformation) {
        Double defaultThreshold = 0.5d;
        LabelProbMapExtractor defaultExtractor = new DefaultLabelProbMapExtractor();
        return calcSampleStatistics(list, tuple2, typeInformation, defaultThreshold, defaultExtractor);
    }

    /**
     * Converts rows into binary-evaluation samples with {@link #getBinaryDetailStatistics}, dropping rows
     * for which it returns {@code null}. It then appends the decision-threshold marker {@code (d, true, NaN)}.
     *
     * @param list                  rows to convert
     * @param tuple2                label information; only the label array {@code f1} is read
     * @param typeInformation       label type
     * @param d                     decision threshold stored in the marker
     * @param labelProbMapExtractor extracts per-label probabilities from a row
     * @return a new mutable list of samples; the marker is the last element
     */
    public static List<Tuple3<Double, Boolean, Double>> calcSampleStatistics(List<Row> list, Tuple2<Map<Object, Integer>, Object[]> tuple2, TypeInformation<?> typeInformation, Double d, LabelProbMapExtractor labelProbMapExtractor) {
        List<Tuple3<Double, Boolean, Double>> samples = new ArrayList<>();
        for (Row row : list) {
            Tuple3<Double, Boolean, Double> sample = getBinaryDetailStatistics(row, tuple2.f1, typeInformation, labelProbMapExtractor);
            if (sample == null) {
                continue;
            }
            samples.add(sample);
        }
        samples.add(Tuple3.of(d, Boolean.TRUE, Double.valueOf(Double.NaN)));
        return samples;
    }

    /**
     * Computes accurate binary-classification metrics for a local (in-memory) list of sample
     * statistics, treating the whole list as a single partition.
     *
     * <p>The list is <b>sorted in place</b> in descending order of the score {@code f0}. AUC is
     * computed with the rank-sum (Mann-Whitney) formula:
     * <ul>
     *   <li>Every sample not recognised by {@code isMiddlePoint} gets a rank. Ranks decrease along
     *       the sorted order, so the highest score gets the largest rank.</li>
     *   <li>Tied scores share the average rank of their group.</li>
     *   <li>{@code AUC = (sumOfPositiveRanks - P * (P + 1) / 2) / (P * N)}, where {@code P} and
     *       {@code N} are the reduced TOTAL_TRUE and TOTAL_FALSE counts.</li>
     * </ul>
     * A sample counts as positive when {@code f1 == true}. If either class is absent, a warning
     * is logged and AUC is {@code NaN}.
     *
     * @param tuple2 label information; only the label array {@code f1} is read here.
     * @param list   per-sample statistics; sorted in place by this method. Middle-point entries
     *               are excluded from the AUC ranking but still passed to the per-sample metric
     *               update.
     * @param d      the decision threshold; must not be {@code null}.
     * @return the accumulated metrics summary.
     */
    public static AccurateBinaryMetricsSummary calLabelPredDetailLocal(Tuple2<Map<Object, Integer>, Object[]> tuple2, List<Tuple3<Double, Boolean, Double>> list, Double d) {
        // Unbox once, up front, so a null threshold fails fast instead of midway through.
        final double threshold = d;

        // Summarise every sample as the single partition 0. NOTE(review): -Double.MAX_VALUE
        // presumably seeds a running maximum score; confirm against BinaryPartitionSummary.
        final BinaryPartitionSummary partitionSummary = new BinaryPartitionSummary(0, -Double.MAX_VALUE, 0L, 0L);
        for (Tuple3<Double, Boolean, Double> sample : list) {
            updateBinaryPartitionSummary(partitionSummary, sample, threshold);
        }

        // Highest score first. This has the same ordering as negating Double.compareTo.
        list.sort((a, b) -> Double.compare(b.f0, a.f0));

        List<BinaryPartitionSummary> summaries = new ArrayList<>();
        summaries.add(partitionSummary);
        Tuple2<Boolean, long[]> reduced = reduceBinaryPartitionSummary(summaries, 0);
        long[] counts = reduced.f1;

        // Assign descending ranks: the first ranked sample gets totalSamples - offset, and each
        // later one gets one less.
        long totalSamples = counts[TOTAL_TRUE] + counts[TOTAL_FALSE];
        long ranked = counts[CUR_FALSE] + counts[CUR_TRUE];
        List<Row> rankedRows = new ArrayList<>();
        for (Tuple3<Double, Boolean, Double> sample : list) {
            if (isMiddlePoint(sample, threshold)) {
                continue;
            }
            ranked++;
            rankedRows.add(Row.of(sample.f0, Long.valueOf(totalSamples - ranked + 1), sample.f1));
        }

        // For each group of tied scores: average rank * number of positives in the group.
        MTableUtil.GroupFunction positiveRankOfTies = new MTableUtil.GroupFunction() {
            @Override
            public void calc(List<Row> rows, Collector<Row> collector) {
                long rankSum = 0L;
                long positives = 0L;
                for (Row row : rows) {
                    rankSum += (Long) row.getField(1);
                    if ((Boolean) row.getField(2)) {
                        positives++;
                    }
                }
                double averageRank = (1.0d * rankSum) / rows.size();
                collector.collect(Row.of(averageRank * positives));
            }
        };
        double positiveRankSum = 0.0d;
        Iterator<Row> groups = MTableUtil.groupFunc(new MTable(rankedRows, "f0 double, f1 long, f2 boolean"), new String[]{"f0"}, positiveRankOfTies).iterator();
        while (groups.hasNext()) {
            positiveRankSum += (Double) groups.next().getField(0);
        }

        // NOTE(review): the meaning of the reduced boolean is defined by reduceBinaryPartitionSummary;
        // it is forwarded unchanged to the per-sample update below.
        boolean reducedFlag = reduced.f0;
        long positiveCount = counts[TOTAL_TRUE];
        long negativeCount = counts[TOTAL_FALSE];
        if (positiveCount == 0) {
            LOG.warn("There is no positive sample in data!");
        }
        if (negativeCount == 0) {
            LOG.warn("There is no negative sample in data!");
        }
        double auc = (positiveCount > 0 && negativeCount > 0)
            ? (positiveRankSum - ((1.0d * positiveCount) * (positiveCount + 1)) / 2.0d) / (positiveCount * negativeCount)
            : Double.NaN;

        AccurateBinaryMetricsSummary summary = new AccurateBinaryMetricsSummary(tuple2.f1, threshold, Criteria.INVALID_GAIN, 0L, auc);
        double[] state = new double[RECORD_LEN];
        for (Tuple3<Double, Boolean, Double> sample : list) {
            updateAccurateBinaryMetricsSummary(sample, summary, counts, state, reducedFlag, threshold, Double.POSITIVE_INFINITY);
        }
        return summary;
    }

    /**
     * Distributed counterpart of the local {@code calcSampleStatistics}.
     *
     * <p>The input rows are rebalanced across partitions, and each partition is processed by a
     * {@code SampleStatisticsMapPartitionFunction}. The label information and the decision
     * threshold reach that function as broadcast sets named {@code LABELS_BC_NAME} and
     * {@code DECISION_THRESHOLD_BC_NAME}.
     *
     * @param dataSet               the rows to evaluate.
     * @param dataSet2              label information, broadcast as {@code LABELS_BC_NAME}.
     * @param typeInformation       type information handed to the partition function.
     * @param dataSet3              decision threshold, broadcast as {@code DECISION_THRESHOLD_BC_NAME}.
     * @param labelProbMapExtractor extractor handed to the partition function.
     * @return the per-sample statistics data set.
     */
    public static DataSet<Tuple3<Double, Boolean, Double>> calcSampleStatistics(DataSet<Row> dataSet, DataSet<Tuple2<Map<Object, Integer>, Object[]>> dataSet2, TypeInformation<?> typeInformation, DataSet<Double> dataSet3, LabelProbMapExtractor labelProbMapExtractor) {
        DataSet<Row> balancedRows = dataSet.rebalance();
        SampleStatisticsMapPartitionFunction statisticsFunction = new SampleStatisticsMapPartitionFunction(typeInformation, labelProbMapExtractor);
        return balancedRows
            .mapPartition(statisticsFunction)
            .withBroadcastSet(dataSet2, LABELS_BC_NAME)
            .withBroadcastSet(dataSet3, DECISION_THRESHOLD_BC_NAME);
    }

    /**
     * Distributed convenience overload that uses a decision threshold of {@code 0.5} and a
     * {@code DefaultLabelProbMapExtractor}.
     *
     * <p>The threshold data set is created from the execution environment of {@code dataSet2}.
     *
     * @param dataSet         the rows to evaluate.
     * @param dataSet2        label information, used as a broadcast set.
     * @param typeInformation type information handed to the partition function.
     * @return the per-sample statistics data set.
     */
    public static DataSet<Tuple3<Double, Boolean, Double>> calcSampleStatistics(DataSet<Row> dataSet, DataSet<Tuple2<Map<Object, Integer>, Object[]>> dataSet2, TypeInformation<?> typeInformation) {
        DataSet<Double> defaultThreshold = dataSet2.getExecutionEnvironment().fromElements(0.5d);
        LabelProbMapExtractor defaultExtractor = new DefaultLabelProbMapExtractor();
        return calcSampleStatistics(dataSet, dataSet2, typeInformation, defaultThreshold, defaultExtractor);
    }
}
