package com.alibaba.alink.operator.common.evaluation;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.operator.common.evaluation.ClassificationEvaluationUtil;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/MultiMetricsSummary.class */
/**
 * Mergeable summary used to build {@link MultiClassMetrics}.
 *
 * <p>It holds a square {@link LongMatrix} of counts, the label array whose positions index that
 * matrix, an accumulated log-loss value and an accumulated total. Two summaries can be combined
 * with {@link #merge(MultiMetricsSummary)} even when their label arrays differ.
 */
public final class MultiMetricsSummary implements BaseMetricsSummary<MultiClassMetrics, MultiMetricsSummary> {
    private static final long serialVersionUID = -8742985165888894890L;

    /** Square count matrix; row/column {@code i} corresponds to {@code labels[i]}. */
    LongMatrix matrix;

    /** Labels, in the order used to index {@link #matrix}. */
    Object[] labels;

    /** Accumulated total; summed across merged summaries. */
    long total;

    /** Accumulated log loss; summed across merged summaries. */
    double logLoss;

    /**
     * Creates a summary from raw counts.
     *
     * @param jArr   non-empty square array of counts (row count must equal the length of the first row)
     * @param objArr labels indexing the rows and columns of {@code jArr}
     * @param d      initial log-loss value
     * @param j      initial total
     */
    public MultiMetricsSummary(long[][] jArr, Object[] objArr, double d, long j) {
        AkPreconditions.checkArgument(jArr.length > 0 && jArr.length == jArr[0].length, "The row size must be equal to col size!");
        this.matrix = new LongMatrix(jArr);
        this.labels = objArr;
        this.logLoss = d;
        this.total = j;
    }

    /**
     * Adds every cell of {@code source} into {@code target}, translating source indices to target
     * indices through the label of each source row/column.
     *
     * @param source       matrix whose counts are added
     * @param sourceLabels labels indexing {@code source}
     * @param target       matrix that receives the counts (mutated in place)
     * @param targetIndex  position of every label inside {@code target}; must contain all {@code sourceLabels}
     */
    private void mergeConfusionMatrix(LongMatrix source, Object[] sourceLabels, LongMatrix target, Map<Object, Integer> targetIndex) {
        for (int row = 0; row < source.getRowNum(); row++) {
            int targetRow = targetIndex.get(sourceLabels[row]);
            for (int col = 0; col < source.getColNum(); col++) {
                int targetCol = targetIndex.get(sourceLabels[col]);
                long updated = target.getValue(targetRow, targetCol) + source.getValue(row, col);
                target.setValue(targetRow, targetCol, updated);
            }
        }
    }

    /**
     * Merges another summary into this one and returns this instance.
     *
     * <p>When both summaries share an equal label array the matrices are added directly. Otherwise
     * the union of both label sets is sorted in reverse natural order, a new matrix over that union
     * is allocated, and both matrices are re-indexed into it.
     *
     * @param multiMetricsSummary the summary to merge in; {@code null} leaves this summary untouched
     * @return this summary, after merging
     */
    @Override // com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary
    public MultiMetricsSummary merge(MultiMetricsSummary multiMetricsSummary) {
        if (multiMetricsSummary == null) {
            return this;
        }
        if (Arrays.equals(this.labels, multiMetricsSummary.labels)) {
            this.matrix.plusEqual(multiMetricsSummary.matrix);
        } else {
            HashSet<Object> labelUnion = new HashSet<>(Arrays.asList(this.labels));
            labelUnion.addAll(Arrays.asList(multiMetricsSummary.labels));

            // Reverse natural order fixes the position of every label in the merged matrix.
            Object[] mergedLabels = labelUnion.toArray();
            Arrays.sort(mergedLabels, Collections.reverseOrder());

            int size = mergedLabels.length;
            Map<Object, Integer> indexByLabel = IntStream.range(0, size)
                .boxed()
                .collect(Collectors.toMap(idx -> mergedLabels[idx], idx -> idx));

            LongMatrix mergedMatrix = new LongMatrix(new long[size][size]);
            mergeConfusionMatrix(this.matrix, this.labels, mergedMatrix, indexByLabel);
            mergeConfusionMatrix(multiMetricsSummary.matrix, multiMetricsSummary.labels, mergedMatrix, indexByLabel);

            this.labels = mergedLabels;
            this.matrix = mergedMatrix;
        }
        this.logLoss += multiMetricsSummary.logLoss;
        this.total += multiMetricsSummary.total;
        return this;
    }

    /**
     * Builds the final {@link MultiClassMetrics} from the accumulated state.
     *
     * <p>Labels are converted with {@code toString()}, a {@link ConfusionMatrix} is built from the
     * count matrix, and the resulting values are stored in a {@link Params} instance that backs
     * the returned metrics object.
     *
     * @return metrics computed from this summary
     */
    @Override // com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary
    public MultiClassMetrics toMetrics() {
        String[] labelStrings = new String[this.labels.length];
        for (int i = 0; i < this.labels.length; i++) {
            labelStrings[i] = this.labels[i].toString();
        }

        Params params = new Params();
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(this.matrix);

        // Params.set is generic in the value type, so no casts are needed on key or value.
        params.set(MultiClassMetrics.PREDICT_LABEL_FREQUENCY, confusionMatrix.getPredictLabelFrequency());
        params.set(MultiClassMetrics.PREDICT_LABEL_PROPORTION, confusionMatrix.getPredictLabelProportion());

        // One array-valued entry per computation defined in ClassificationEvaluationUtil.Computations.
        for (ClassificationEvaluationUtil.Computations computation : ClassificationEvaluationUtil.Computations.values()) {
            params.set(computation.arrayParamInfo, ClassificationEvaluationUtil.getAllValues(computation.computer, confusionMatrix));
        }

        ClassificationEvaluationUtil.setClassificationCommonParams(params, confusionMatrix, labelStrings);
        ClassificationEvaluationUtil.setLoglossParams(params, this.logLoss, this.total);
        return new MultiClassMetrics(params);
    }
}
