package com.alibaba.alink.operator.common.evaluation;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.io.Serializable;
import java.util.Arrays;

/* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/ConfusionMatrix.class */
/**
 * Square confusion matrix over {@code labelCnt} labels, with cached marginal frequencies and
 * per-label / aggregate true-false positive-negative counts.
 *
 * <p>Row sums of the underlying {@link LongMatrix} are used as the predicted-label frequencies and
 * column sums as the actual-label frequencies, so entry {@code (i, j)} is treated as the number of
 * samples predicted as label {@code i} whose actual label is {@code j}.
 */
public class ConfusionMatrix implements Serializable {
    private static final long serialVersionUID = -363689060257724357L;
    LongMatrix longMatrix;
    int labelCnt;
    long total;
    private long[] actualLabelFrequency;
    private long[] predictLabelFrequency;
    // Aggregate counts: the per-label value summed over every label, computed once in the constructor.
    private double tpCount;
    private double tnCount;
    private double fpCount;
    private double fnCount;

    /**
     * Creates a confusion matrix from raw counts.
     *
     * @param jArr count matrix; must be square (checked by {@link #ConfusionMatrix(LongMatrix)})
     */
    public ConfusionMatrix(long[][] jArr) {
        this(new LongMatrix(jArr));
    }

    /**
     * Creates a confusion matrix backed by {@code longMatrix} and precomputes its marginals and
     * aggregate TP/TN/FP/FN counts.
     *
     * @param longMatrix square count matrix
     * @throws RuntimeException via {@code AkPreconditions} if the matrix is not square
     */
    public ConfusionMatrix(LongMatrix longMatrix) {
        AkPreconditions.checkArgument(longMatrix.getRowNum() == longMatrix.getColNum(), "The row size must be equal to col size!");
        this.longMatrix = longMatrix;
        this.labelCnt = this.longMatrix.getRowNum();
        this.actualLabelFrequency = longMatrix.getColSums();
        this.predictLabelFrequency = longMatrix.getRowSums();
        this.total = longMatrix.getTotal();

        // Sums must start from zero. Private int-indexed helpers avoid calling overridable
        // methods from the constructor and avoid boxing every index.
        double tp = 0.0;
        double tn = 0.0;
        double fp = 0.0;
        double fn = 0.0;
        for (int i = 0; i < this.labelCnt; i++) {
            tp += truePositive(i);
            tn += trueNegative(i);
            fp += falsePositive(i);
            fn += falseNegative(i);
        }
        this.tpCount = tp;
        this.tnCount = tn;
        this.fpCount = fp;
        this.fnCount = fn;
    }

    /** Returns the per-label actual frequencies (column sums). The internal array is returned, not a copy. */
    public long[] getActualLabelFrequency() {
        return this.actualLabelFrequency;
    }

    /**
     * Returns each actual-label frequency divided by the total count.
     *
     * @return proportions in floating point; NaN entries when the matrix total is zero
     */
    public double[] getActualLabelProportion() {
        return proportionsOf(this.actualLabelFrequency);
    }

    /** Returns the per-label predicted frequencies (row sums). The internal array is returned, not a copy. */
    public long[] getPredictLabelFrequency() {
        return this.predictLabelFrequency;
    }

    /**
     * Returns each predicted-label frequency divided by the total count.
     *
     * @return proportions in floating point; NaN entries when the matrix total is zero
     */
    public double[] getPredictLabelProportion() {
        return proportionsOf(this.predictLabelFrequency);
    }

    /**
     * Computes Cohen's kappa: {@code (po - pe) / (1 - pe)}, where {@code po} is the observed
     * agreement (diagonal sum / total) and {@code pe} is the chance agreement from the marginals.
     *
     * @return kappa, or {@code 1.0} when {@code pe >= 1} (also the result when {@code pe} is NaN)
     */
    public double getTotalKappa() {
        double chanceAgreement = 0.0;
        for (int i = 0; i < this.labelCnt; i++) {
            // Multiply in double: a long product of two marginals can overflow.
            chanceAgreement += (double) this.predictLabelFrequency[i] * this.actualLabelFrequency[i];
        }
        // Square the total in double: total * total overflows long once total exceeds ~3.04e9.
        double totalAsDouble = this.total;
        double pe = chanceAgreement / (totalAsDouble * totalAsDouble);
        double po = diagonalSum() / totalAsDouble;
        if (pe < 1.0d) {
            return (po - pe) / (1.0d - pe);
        }
        return 1.0d;
    }

    /**
     * Returns the overall accuracy: the diagonal sum divided by the total count.
     *
     * @return accuracy; NaN when the matrix total is zero
     */
    public double getTotalAccuracy() {
        return diagonalSum() / this.total;
    }

    /**
     * Returns the true-positive count for one label, or the sum over all labels.
     *
     * @param num label index in {@code [0, labelCnt)}, or {@code null} for the aggregate
     * @return the count, as a double
     */
    public double numTruePositive(Integer num) {
        checkLabelIndex(num);
        return null == num ? this.tpCount : truePositive(num);
    }

    /**
     * Returns the true-negative count for one label, or the sum over all labels.
     *
     * @param num label index in {@code [0, labelCnt)}, or {@code null} for the aggregate
     * @return the count, as a double
     */
    public double numTrueNegative(Integer num) {
        checkLabelIndex(num);
        return null == num ? this.tnCount : trueNegative(num);
    }

    /**
     * Returns the false-positive count for one label, or the sum over all labels.
     *
     * @param num label index in {@code [0, labelCnt)}, or {@code null} for the aggregate
     * @return the count, as a double
     */
    public double numFalsePositive(Integer num) {
        checkLabelIndex(num);
        return null == num ? this.fpCount : falsePositive(num);
    }

    /**
     * Returns the false-negative count for one label, or the sum over all labels.
     *
     * @param num label index in {@code [0, labelCnt)}, or {@code null} for the aggregate
     * @return the count, as a double
     */
    public double numFalseNegative(Integer num) {
        checkLabelIndex(num);
        return null == num ? this.fnCount : falseNegative(num);
    }

    /** Rejects label indices outside {@code [0, labelCnt)}; {@code null} means "aggregate" and is allowed. */
    private void checkLabelIndex(Integer labelIndex) {
        AkPreconditions.checkArgument(
            null == labelIndex || (labelIndex.intValue() >= 0 && labelIndex.intValue() < this.labelCnt),
            "labelIndex must be null or less than " + this.labelCnt);
    }

    /** Diagonal entry for label {@code i}. */
    private double truePositive(int i) {
        return this.longMatrix.getValue(i, i);
    }

    /** {@code total - predicted[i] - actual[i] + diag[i]}: samples neither predicted nor actually labelled {@code i}. */
    private double trueNegative(int i) {
        return ((this.longMatrix.getValue(i, i) + this.total) - this.predictLabelFrequency[i]) - this.actualLabelFrequency[i];
    }

    /** Predicted as {@code i} minus correctly predicted {@code i}. */
    private double falsePositive(int i) {
        return this.predictLabelFrequency[i] - this.longMatrix.getValue(i, i);
    }

    /** Actually {@code i} minus correctly predicted {@code i}. */
    private double falseNegative(int i) {
        return this.actualLabelFrequency[i] - this.longMatrix.getValue(i, i);
    }

    /** Sum of the diagonal, read from the live matrix. */
    private double diagonalSum() {
        double sum = 0.0d;
        for (int i = 0; i < this.labelCnt; i++) {
            sum += this.longMatrix.getValue(i, i);
        }
        return sum;
    }

    /** Divides each frequency by the total in floating point; the original long/long division truncated to 0. */
    private double[] proportionsOf(long[] frequency) {
        double[] proportions = new double[this.labelCnt];
        for (int i = 0; i < this.labelCnt; i++) {
            proportions[i] = (double) frequency[i] / this.total;
        }
        return proportions;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof ConfusionMatrix)) {
            return false;
        }
        ConfusionMatrix confusionMatrix = (ConfusionMatrix) obj;
        return this.longMatrix.equals(confusionMatrix.longMatrix) && this.labelCnt == confusionMatrix.labelCnt && this.total == confusionMatrix.total && Arrays.equals(this.actualLabelFrequency, confusionMatrix.actualLabelFrequency) && Arrays.equals(this.predictLabelFrequency, confusionMatrix.predictLabelFrequency);
    }

    /**
     * Hashes only fields that {@link #equals(Object)} also compares, so equal instances always
     * hash equally without relying on {@code LongMatrix.hashCode()}.
     */
    @Override
    public int hashCode() {
        int result = this.labelCnt;
        result = 31 * result + (int) (this.total ^ (this.total >>> 32));
        result = 31 * result + Arrays.hashCode(this.actualLabelFrequency);
        result = 31 * result + Arrays.hashCode(this.predictLabelFrequency);
        return result;
    }
}
