package com.alibaba.alink.operator.common.evaluation;

import com.alibaba.alink.common.exceptions.AkIllegalOperationException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.operator.common.evaluation.BaseBinaryClassMetrics;
import com.alibaba.alink.operator.common.feature.binning.FeatureBinsUtil;
import com.alibaba.alink.operator.common.outlier.TimeSeriesAnomsUtils;
import com.alibaba.alink.operator.common.tree.viz.TreeModelViz;
import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Font;
import java.awt.Paint;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;
import javax.imageio.stream.FileImageOutputStream;
import org.apache.commons.lang.ArrayUtils;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.jfree.chart.ChartColor;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartRenderingInfo;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.title.TextTitle;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.ui.HorizontalAlignment;
import org.jfree.ui.RectangleEdge;

/* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/BaseBinaryClassMetrics.class */
/**
 * Evaluation metrics for binary classification.
 *
 * <p>In addition to what {@link BaseSimpleClassifierMetrics} provides, this class exposes scalar summaries
 * (AUC, GINI, K-S, PRC, precision, recall, F1), curves stored as two-row {@code double[][]} params (row 0 is
 * returned as {@code f0}, row 1 as {@code f1}), per-threshold metric arrays, and helpers that render several of
 * the curves to image files with JFreeChart.
 *
 * @param <T> the concrete metrics type
 */
public class BaseBinaryClassMetrics<T extends BaseBinaryClassMetrics<T>> extends BaseSimpleClassifierMetrics<T> {
    private static final long serialVersionUID = -4566069100819806462L;

    /** Width, in pixels, of images written by {@link #saveAsImage}. */
    private static final int IMAGE_WIDTH = 500;
    /** Height, in pixels, of images written by {@link #saveAsImage}. */
    private static final int IMAGE_HEIGHT = 400;
    /** Number of decimals kept when a summary value is shown in a chart subtitle. */
    private static final int SUBTITLE_DECIMALS = 3;
    /** Stroke width used for every plotted series. */
    private static final float SERIES_STROKE_WIDTH = 2.0f;

    static final ParamInfo<double[][]> ROC_CURVE = ParamInfoFactory.createParamInfo("RocCurve", double[][].class)
        .setDescription("roc curve").setRequired().build();
    public static final ParamInfo<Double> AUC = ParamInfoFactory.createParamInfo("AUC", Double.class)
        .setDescription("auc").setRequired().build();
    public static final ParamInfo<Double> GINI = ParamInfoFactory.createParamInfo("GINI", Double.class)
        .setDescription("GINI").setRequired().build();
    public static final ParamInfo<Double> KS = ParamInfoFactory.createParamInfo("K-S", Double.class)
        .setDescription("ks").setRequired().build();
    public static final ParamInfo<Double> PRC = ParamInfoFactory.createParamInfo("PRC", Double.class)
        .setDescription("prc").setRequired().build();
    static final ParamInfo<double[][]> PRECISION_RECALL_CURVE = ParamInfoFactory
        .createParamInfo("PrecisionRecallCurve", double[][].class)
        .setDescription("recall precision curve").setRequired().build();
    static final ParamInfo<double[][]> LIFT_CHART = ParamInfoFactory.createParamInfo("LiftChart", double[][].class)
        .setDescription("liftchart").setRequired().build();
    public static final ParamInfo<double[][]> LORENZ_CURVE = ParamInfoFactory
        .createParamInfo("LorenzCurve", double[][].class)
        .setDescription("lorenzCurve").setRequired().build();
    static final ParamInfo<double[]> THRESHOLD_ARRAY = ParamInfoFactory.createParamInfo("ThresholdArray", double[].class)
        .setDescription("threshold list").setRequired().build();
    static final ParamInfo<Double> PRECISION = ParamInfoFactory.createParamInfo("Precision", Double.class)
        .setDescription("precision").setRequired().build();
    static final ParamInfo<Double> RECALL = ParamInfoFactory.createParamInfo("Recall", Double.class)
        .setDescription("recall").setRequired().build();
    static final ParamInfo<Double> F1 = ParamInfoFactory.createParamInfo("F1", Double.class)
        .setDescription("f1").setRequired().build();

    /**
     * Renders the headline metrics followed by the 2x2 confusion matrix (rows/columns labelled with
     * {@link #getLabelArray()}).
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(PrettyDisplayUtils.displayHeadline("Metrics:", '-'));
        String[] labels = getLabelArray();
        long[][] confusion = getConfusionMatrix();
        // Binary classification: the confusion matrix is rendered as exactly 2x2.
        String[][] cells = new String[2][2];
        for (int row = 0; row < 2; row++) {
            for (int col = 0; col < 2; col++) {
                cells[row][col] = String.valueOf(confusion[row][col]);
            }
        }
        sb.append("Auc:").append(PrettyDisplayUtils.display(getAuc())).append("\t")
            .append("Accuracy:").append(PrettyDisplayUtils.display(Double.valueOf(getAccuracy()))).append("\t")
            .append("Precision:").append(PrettyDisplayUtils.display(getPrecision())).append("\t")
            .append("Recall:").append(PrettyDisplayUtils.display(getRecall())).append("\t")
            .append("F1:").append(PrettyDisplayUtils.display(getF1())).append("\t")
            .append("LogLoss:").append(PrettyDisplayUtils.display(getLogLoss()))
            .append(CsvInputFormat.DEFAULT_LINE_DELIMITER)
            .append(PrettyDisplayUtils.displayTable(cells, 2, 2, labels, labels, "Pred\\Real"));
        return sb.toString();
    }

    public BaseBinaryClassMetrics(Row row) {
        super(row);
    }

    public BaseBinaryClassMetrics(Params params) {
        super(params);
    }

    /**
     * Reads a two-row curve param and returns its rows as a pair.
     *
     * @param info the curve param
     * @return row 0 as {@code f0}, row 1 as {@code f1}
     */
    private Tuple2<double[], double[]> readCurve(ParamInfo<double[][]> info) {
        double[][] curve = getParams().get(info);
        return Tuple2.of(curve[0], curve[1]);
    }

    /** Returns the ROC curve; {@code f0} is plotted as FPR and {@code f1} as TPR by {@link #saveRocCurveAsImage}. */
    public Tuple2<double[], double[]> getRocCurve() {
        return readCurve(ROC_CURVE);
    }

    /** Returns the Lorenz curve as (x, y) arrays. (The misspelled name is kept for API compatibility.) */
    public Tuple2<double[], double[]> getLorenzeCurve() {
        return readCurve(LORENZ_CURVE);
    }

    public Double getPrecision() {
        return (Double) get(PRECISION);
    }

    public Double getRecall() {
        return (Double) get(RECALL);
    }

    public Double getF1() {
        return (Double) get(F1);
    }

    public Double getAuc() {
        return (Double) get(AUC);
    }

    public Double getGini() {
        return (Double) get(GINI);
    }

    public Double getKs() {
        return (Double) get(KS);
    }

    public Double getPrc() {
        return (Double) get(PRC);
    }

    /** Returns the precision-recall curve; {@code f0} is plotted as recall and {@code f1} as precision. */
    public Tuple2<double[], double[]> getPrecisionRecallCurve() {
        return readCurve(PRECISION_RECALL_CURVE);
    }

    /** Returns the lift chart as (x, y) arrays. */
    public Tuple2<double[], double[]> getLiftChart() {
        return readCurve(LIFT_CHART);
    }

    /** Returns the decision thresholds that the {@code *ByThreshold} metric arrays are paired with. */
    public double[] getThresholdArray() {
        return (double[]) get(THRESHOLD_ARRAY);
    }

    public Tuple2<double[], double[]> getPrecisionByThreshold() {
        return Tuple2.of(getParams().get(THRESHOLD_ARRAY), getParams().get(PRECISION_ARRAY));
    }

    public Tuple2<double[], double[]> getSpecificityByThreshold() {
        return Tuple2.of(getParams().get(THRESHOLD_ARRAY), getParams().get(SPECIFICITY_ARRAY));
    }

    public Tuple2<double[], double[]> getSensitivityByThreshold() {
        return Tuple2.of(getParams().get(THRESHOLD_ARRAY), getParams().get(SENSITIVITY_ARRAY));
    }

    public Tuple2<double[], double[]> getRecallByThreshold() {
        return Tuple2.of(getParams().get(THRESHOLD_ARRAY), getParams().get(RECALL_ARRAY));
    }

    public Tuple2<double[], double[]> getF1ByThreshold() {
        return Tuple2.of(getParams().get(THRESHOLD_ARRAY), getParams().get(F1_ARRAY));
    }

    public Tuple2<double[], double[]> getAccuracyByThreshold() {
        return Tuple2.of(getParams().get(THRESHOLD_ARRAY), getParams().get(ACCURACY_ARRAY));
    }

    public Tuple2<double[], double[]> getKappaByThreshold() {
        return Tuple2.of(getParams().get(THRESHOLD_ARRAY), getParams().get(KAPPA_ARRAY));
    }

    /**
     * Plots one or more (x, y) curves as an XY line chart and writes it to an image file.
     *
     * <p>The image format is taken from {@code TreeModelViz.getFormat(filePath)} (presumably the file extension).
     * The target file is replaced completely, so overwriting a larger existing file leaves no stale bytes.
     *
     * @param filePath    destination file
     * @param isOverwrite whether an existing file may be replaced
     * @param title       chart title
     * @param xLabel      x axis label
     * @param yLabel      y axis label
     * @param seriesNames one name per curve; a legend is shown when more than one curve is plotted
     * @param summary     optional (name, value) shown as a left-aligned subtitle; may be {@code null}
     * @param curves      curves to plot; for each, point i is ({@code f0[i]}, {@code f1[i]}); extra y values are
     *                    ignored
     * @throws AkIllegalOperationException if the file exists and {@code isOverwrite} is false
     * @throws IllegalArgumentException    if there are fewer names than curves, or a curve has fewer y than x values
     * @throws IOException                 if no image writer supports the format, or writing fails
     */
    @SafeVarargs
    static void saveAsImage(String filePath, boolean isOverwrite, String title, String xLabel, String yLabel,
                            String[] seriesNames, Tuple2<String, Double> summary,
                            Tuple2<double[], double[]>... curves) throws IOException {
        File file = new File(filePath);
        if (!isOverwrite && file.exists()) {
            throw new AkIllegalOperationException(
                String.format("File %s exists and isOverwrite is set to false.", filePath));
        }
        AkPreconditions.checkNotNull(curves, "Points should not be null!");
        if (seriesNames == null || seriesNames.length < curves.length) {
            throw new IllegalArgumentException(String.format(
                "Expected at least %d series names, got %d.",
                curves.length, seriesNames == null ? 0 : seriesNames.length));
        }

        JFreeChart chart = ChartFactory.createXYLineChart(title, xLabel, yLabel,
            buildDataset(seriesNames, curves), PlotOrientation.VERTICAL, curves.length > 1, true, false);
        styleChart(chart, curves.length);
        if (summary != null) {
            addSummarySubtitle(chart, summary);
        }

        // ImageIO.write(..., File) removes any existing file before writing and always closes its stream,
        // unlike a hand-made FileImageOutputStream, which neither truncates nor gets closed.
        boolean written = ImageIO.write(
            chart.createBufferedImage(IMAGE_WIDTH, IMAGE_HEIGHT, java.awt.image.BufferedImage.TYPE_INT_RGB,
                (ChartRenderingInfo) null),
            TreeModelViz.getFormat(filePath), file);
        if (!written) {
            throw new IOException(String.format("No image writer found for the format of file %s.", filePath));
        }
    }

    /** Converts each curve into a named XY series; point i of a curve is ({@code f0[i]}, {@code f1[i]}). */
    private static XYSeriesCollection buildDataset(String[] seriesNames, Tuple2<double[], double[]>[] curves) {
        XYSeriesCollection dataset = new XYSeriesCollection();
        for (int s = 0; s < curves.length; s++) {
            double[] xs = curves[s].f0;
            double[] ys = curves[s].f1;
            if (ys.length < xs.length) {
                throw new IllegalArgumentException(String.format(
                    "Series %s has %d x values but only %d y values.", seriesNames[s], xs.length, ys.length));
            }
            XYSeries series = new XYSeries(seriesNames[s]);
            for (int p = 0; p < xs.length; p++) {
                series.add(xs[p], ys[p]);
            }
            dataset.addSeries(series);
        }
        return dataset;
    }

    /** Applies the white background / black grid style and a uniform stroke to the first {@code seriesCount} series. */
    private static void styleChart(JFreeChart chart, int seriesCount) {
        chart.setBackgroundPaint(Color.white);
        XYPlot plot = chart.getXYPlot();
        plot.setBackgroundPaint(ChartColor.WHITE);
        plot.setRangeGridlinePaint(ChartColor.BLACK);
        plot.setDomainGridlinePaint(ChartColor.BLACK);
        plot.setOutlinePaint((Paint) null);
        plot.getRenderer().setSeriesPaint(0, ChartColor.BLACK);
        for (int s = 0; s < seriesCount; s++) {
            plot.getRenderer().setSeriesStroke(s, new BasicStroke(SERIES_STROKE_WIDTH));
        }
    }

    /** Adds a top, left-aligned subtitle showing the summary name and its value rounded for display. */
    private static void addSummarySubtitle(JFreeChart chart, Tuple2<String, Double> summary) {
        TextTitle subtitle = new TextTitle(summary.f0 + TimeSeriesAnomsUtils.VAL_DELIMITER
            + FeatureBinsUtil.keepGivenDecimal(summary.f1, SUBTITLE_DECIMALS));
        subtitle.setFont(new Font("SansSerif", Font.PLAIN, 15));
        subtitle.setPosition(RectangleEdge.TOP);
        subtitle.setHorizontalAlignment(HorizontalAlignment.LEFT);
        chart.addSubtitle(subtitle);
    }

    /**
     * Saves the ROC curve (FPR vs TPR) with the AUC as subtitle.
     *
     * @param filePath    destination image file
     * @param isOverwrite whether an existing file may be replaced
     * @throws IOException if the image cannot be written
     */
    public void saveRocCurveAsImage(String filePath, boolean isOverwrite) throws IOException {
        saveAsImage(filePath, isOverwrite, "ROC Curve", "FPR", "TPR", new String[] {"ROC"},
            Tuple2.of("AUC", getAuc()), getRocCurve());
    }

    /**
     * Saves the K-S chart: TPR and FPR plotted against the reversed threshold array, with the K-S value as subtitle.
     *
     * @param filePath    destination image file
     * @param isOverwrite whether an existing file may be replaced
     * @throws IOException if the image cannot be written
     */
    public void saveKSAsImage(String filePath, boolean isOverwrite) throws IOException {
        Tuple2<double[], double[]> roc = getRocCurve();
        // Reverse copies so that arrays possibly shared with the stored params are never mutated.
        double[] thresholds = getThresholdArray().clone();
        double[] tpr = roc.f1.clone();
        double[] fpr = roc.f0.clone();
        ArrayUtils.reverse(thresholds);
        ArrayUtils.reverse(tpr);
        ArrayUtils.reverse(fpr);
        saveAsImage(filePath, isOverwrite, "K-S Curve", "Thresholds", "Rate", new String[] {"TPR", "FPR"},
            Tuple2.of("KS", getKs()), Tuple2.of(thresholds, tpr), Tuple2.of(thresholds, fpr));
    }

    /**
     * Saves the lift chart (no subtitle).
     *
     * @param filePath    destination image file
     * @param isOverwrite whether an existing file may be replaced
     * @throws IOException if the image cannot be written
     */
    public void saveLiftChartAsImage(String filePath, boolean isOverwrite) throws IOException {
        saveAsImage(filePath, isOverwrite, "LiftChart", "Positive Rate", "True Positive",
            new String[] {"LiftChart"}, null, getLiftChart());
    }

    /**
     * Saves the precision-recall curve with the PRC value as subtitle.
     *
     * @param filePath    destination image file
     * @param isOverwrite whether an existing file may be replaced
     * @throws IOException if the image cannot be written
     */
    public void savePrecisionRecallCurveAsImage(String filePath, boolean isOverwrite) throws IOException {
        saveAsImage(filePath, isOverwrite, "PrecisionRecallCurve", "Recall", "Precision",
            new String[] {"PrecisionRecall"}, Tuple2.of("PRC", getPrc()), getPrecisionRecallCurve());
    }

    /**
     * Saves the Lorenz curve with the GINI value as subtitle.
     *
     * @param filePath    destination image file
     * @param isOverwrite whether an existing file may be replaced
     * @throws IOException if the image cannot be written
     */
    public void saveLorenzCurveAsImage(String filePath, boolean isOverwrite) throws IOException {
        saveAsImage(filePath, isOverwrite, "LorenzCurve", "Positive Rate", "TPR", new String[] {"LorenzCurve"},
            Tuple2.of("GINI", getGini()), getLorenzeCurve());
    }
}
