package com.alibaba.alink.operator.common.evaluation;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.distance.CosineDistance;
import com.alibaba.alink.operator.common.distance.EuclideanDistance;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/ClusterEvaluationUtil.class */
/**
 * Static helpers used when evaluating clustering results: metrics derived from a confusion matrix,
 * per-sample silhouette coefficients, and per-cluster summary statistics.
 */
public class ClusterEvaluationUtil {
    private static final long serialVersionUID = -7300130718897249710L;

    // NOTE(review): these two indices are not used inside this file; presumably they address
    // positions in a statistics tuple/array built elsewhere - confirm against callers.
    public static int COUNT = 0;
    public static int MEAN = 1;

    /**
     * Map function that converts a {@link BaseMetricsSummary} into {@link Params} and adds the
     * silhouette coefficient, computed as the first element of the broadcast variable
     * {@code "silhouetteCoefficient"} divided by the {@code ClusterMetrics.COUNT} entry.
     */
    public static class SaveDataAsParams extends RichMapFunction<BaseMetricsSummary, Params> {
        private static final long serialVersionUID = -7830919170205689185L;

        /**
         * @param baseMetricsSummary summary whose metrics provide the base parameters
         * @return the summary's parameters with {@code ClusterMetrics.SILHOUETTE_COEFFICIENT} set
         * @throws Exception declared by the map-function contract
         */
        @Override
        public Params map(BaseMetricsSummary baseMetricsSummary) throws Exception {
            Params params = baseMetricsSummary.toMetrics().getParams();
            // Only the first broadcast element is read; it is divided by the sample count below.
            Tuple1<Double> silhouetteSum =
                getRuntimeContext().<Tuple1<Double>>getBroadcastVariable("silhouetteCoefficient").get(0);
            int sampleCount = params.get(ClusterMetrics.COUNT).intValue();
            params.set(ClusterMetrics.SILHOUETTE_COEFFICIENT,
                Double.valueOf(silhouetteSum.f0.doubleValue() / sampleCount));
            return params;
        }
    }

    /**
     * Derives NMI, purity, Rand index (RI) and adjusted Rand index (ARI) from a confusion matrix
     * and packs them, together with the matrix and the two name arrays, into {@link Params}.
     *
     * @param longMatrix confusion matrix with its row sums, column sums and grand total
     * @param map        maps each label object to its index; written to {@code ClusterMetrics.LABEL_ARRAY}
     * @param map2       maps each prediction object to its index; written to {@code ClusterMetrics.PRED_ARRAY}
     * @return parameters holding the computed metrics, the raw matrix and both name arrays
     */
    public static Params extractParamsFromConfusionMatrix(LongMatrix longMatrix, Map<Object, Integer> map, Map<Object, Integer> map2) {
        long[][] matrix = longMatrix.getMatrix();
        long[] colSums = longMatrix.getColSums();
        long[] rowSums = longMatrix.getRowSums();
        long total = longMatrix.getTotal();

        // Entropy (converted to base 2) and number of same-group pairs for the column partition.
        double colEntropy = 0.0d;
        long colPairs = 0;
        for (long colSum : colSums) {
            colEntropy += entropy(colSum, total);
            colPairs += combination(colSum);
        }
        colEntropy = colEntropy / (-Math.log(2.0d));

        // Same quantities for the row partition.
        double rowEntropy = 0.0d;
        long rowPairs = 0;
        for (long rowSum : rowSums) {
            rowEntropy += entropy(rowSum, total);
            rowPairs += combination(rowSum);
        }
        rowEntropy = rowEntropy / (-Math.log(2.0d));

        double mutualInformation = 0.0d;
        double rowMaxSum = 0.0d;
        long bothPairs = 0;
        for (int r = 0; r < matrix.length; r++) {
            long rowMax = 0;
            for (int c = 0; c < matrix[0].length; c++) {
                long cell = matrix[r][c];
                rowMax = Math.max(rowMax, cell);
                // Empty cells contribute Criteria.INVALID_GAIN instead of evaluating log(0).
                mutualInformation += 0 == cell
                    ? Criteria.INVALID_GAIN
                    : ((1.0d * cell) / total) * Math.log((((1.0d * total) * cell) / rowSums[r]) / colSums[c]);
                bothPairs += combination(cell);
            }
            rowMaxSum += rowMax;
        }
        double purity = rowMaxSum / total;
        mutualInformation = mutualInformation / Math.log(2.0d);

        // Pair counting: pairs grouped together in both partitions, in only one of them, or in neither.
        long totalPairs = combination(total);
        long colOnlyPairs = colPairs - bothPairs;
        long rowOnlyPairs = rowPairs - bothPairs;
        long neitherPairs = ((totalPairs - bothPairs) - rowOnlyPairs) - colOnlyPairs;

        double expectedIndex = ((1.0d * colPairs) * rowPairs) / totalPairs;
        double maxIndex = (1.0d * (colPairs + rowPairs)) / 2.0d;
        // Rand index: agreeing pairs over all pairs (all four pair classes in the denominator).
        double randIndex = (1.0d * (bothPairs + neitherPairs))
            / (((bothPairs + neitherPairs) + colOnlyPairs) + rowOnlyPairs);

        // Invert the name -> index maps into index-ordered name arrays.
        String[] labelNames = new String[map.size()];
        for (Map.Entry<Object, Integer> entry : map.entrySet()) {
            labelNames[entry.getValue().intValue()] = entry.getKey().toString();
        }
        String[] predNames = new String[map2.size()];
        for (Map.Entry<Object, Integer> entry : map2.entrySet()) {
            predNames[entry.getValue().intValue()] = entry.getKey().toString();
        }

        return new Params()
            .set(ClusterMetrics.NMI, Double.valueOf((2.0d * mutualInformation) / (colEntropy + rowEntropy)))
            .set(ClusterMetrics.PURITY, Double.valueOf(purity))
            .set(ClusterMetrics.RI, Double.valueOf(randIndex))
            .set(ClusterMetrics.ARI, Double.valueOf((bothPairs - expectedIndex) / (maxIndex - expectedIndex)))
            .set(ClusterMetrics.CONFUSION_MATRIX, matrix)
            .set(ClusterMetrics.LABEL_ARRAY, labelNames)
            .set(ClusterMetrics.PRED_ARRAY, predNames);
    }

    /** Returns the number of unordered pairs among {@code j} items, i.e. {@code j * (j - 1) / 2}. */
    private static long combination(long j) {
        return (j * (j - 1)) / 2;
    }

    /**
     * Returns {@code p * ln(p)} with {@code p = j / j2}, or {@code Criteria.INVALID_GAIN} when
     * {@code j} is zero. The result is not negated; callers apply the sign and the log base.
     */
    private static double entropy(long j, long j2) {
        double p = (1.0d * j) / j2;
        return 0 == j ? Criteria.INVALID_GAIN : p * Math.log(p);
    }

    /**
     * Computes the silhouette coefficient of one sample from per-cluster summary statistics.
     *
     * <p>With a {@link EuclideanDistance} the dissimilarity to a cluster is derived from
     * {@code cnt * |x|^2 - 2 * cnt * (x . mean) + vectorNormL2Sum}, divided by {@code cnt}
     * ({@code cnt - 1} for the sample's own cluster). Otherwise the sample is normalised
     * <em>in place</em> to unit length and the dissimilarity is {@code 1 - (x . meanVector)}.
     * When the sample's own cluster has at most one member its own-cluster dissimilarity stays 0.
     *
     * @param tuple2                sample vector (f0) and the id of the cluster it belongs to (f1)
     * @param clusterMetricsSummary per-cluster counts, ids, mean vectors and norm sums
     * @return a one-field tuple holding the coefficient; 0 when both dissimilarities are equal
     */
    public static Tuple1<Double> calSilhouetteCoefficient(Tuple2<Vector, String> tuple2, ClusterMetricsSummary clusterMetricsSummary) {
        String ownClusterId = tuple2.f1;
        Vector sample = tuple2.f0;
        double ownDissimilarity = 0.0d;
        double nearestOtherDissimilarity = Double.MAX_VALUE;

        if (clusterMetricsSummary.distance instanceof EuclideanDistance) {
            double sampleNormSquare = sample.normL2Square();
            for (int idx = 0; idx < clusterMetricsSummary.k; idx++) {
                int clusterSize = clusterMetricsSummary.clusterCnt.get(idx).intValue();
                double dissimilaritySum = ((clusterSize * sampleNormSquare)
                    - ((2 * clusterSize) * MatVecOp.dot(sample, clusterMetricsSummary.meanVector.get(idx))))
                    + clusterMetricsSummary.vectorNormL2Sum.get(idx).doubleValue();
                if (!ownClusterId.equals(clusterMetricsSummary.clusterId.get(idx))) {
                    nearestOtherDissimilarity = Math.min(nearestOtherDissimilarity, dissimilaritySum / clusterSize);
                } else if (clusterSize > 1) {
                    // Exclude the sample itself from its own cluster's average.
                    ownDissimilarity = dissimilaritySum / (clusterSize - 1);
                }
            }
        } else {
            // Mutates the caller's vector: it is rescaled to unit length before the dot products.
            sample.scaleEqual(1.0d / sample.normL2());
            for (int idx = 0; idx < clusterMetricsSummary.k; idx++) {
                int clusterSize = clusterMetricsSummary.clusterCnt.get(idx).intValue();
                double dissimilarity = 1.0d - MatVecOp.dot(sample, clusterMetricsSummary.meanVector.get(idx));
                if (!ownClusterId.equals(clusterMetricsSummary.clusterId.get(idx))) {
                    nearestOtherDissimilarity = Math.min(nearestOtherDissimilarity, dissimilarity);
                } else if (clusterSize > 1) {
                    ownDissimilarity = (dissimilarity * clusterSize) / (clusterSize - 1);
                }
            }
        }

        double coefficient;
        if (ownDissimilarity < nearestOtherDissimilarity) {
            coefficient = 1.0d - (ownDissimilarity / nearestOtherDissimilarity);
        } else if (ownDissimilarity == nearestOtherDissimilarity) {
            // Equal dissimilarities give 0; this also avoids 0/0 = NaN when both are zero,
            // which would otherwise poison any sum of coefficients. NaN inputs still fall through.
            coefficient = 0.0d;
        } else {
            coefficient = (nearestOtherDissimilarity / ownDissimilarity) - 1.0d;
        }
        return Tuple1.of(Double.valueOf(coefficient));
    }

    /**
     * Counts rows per cluster id, where the id is the string form of field 0 of each row.
     * Null rows and rows whose field 0 is null are ignored.
     *
     * @param iterable rows whose first field identifies the cluster
     * @return parameters with the total row count, the number of distinct clusters, the cluster
     *         ids and the matching per-cluster counts (same order in both arrays)
     */
    public static Params getBasicClusterStatistics(Iterable<Row> iterable) {
        Map<String, Double> clusterSizes = new HashMap<>(0);
        int total = 0;
        for (Row row : iterable) {
            if (row != null && row.getField(0) != null) {
                total++;
                clusterSizes.merge(row.getField(0).toString(), Double.valueOf(1.0d),
                    (current, ignored) -> Double.valueOf(current.doubleValue() + 1.0d));
            }
        }

        // Flatten the map into two parallel arrays.
        double[] counts = new double[clusterSizes.size()];
        String[] clusterIds = new String[clusterSizes.size()];
        int pos = 0;
        for (Map.Entry<String, Double> entry : clusterSizes.entrySet()) {
            clusterIds[pos] = entry.getKey();
            counts[pos] = entry.getValue().doubleValue();
            pos++;
        }

        return new Params()
            .set(ClusterMetrics.COUNT, Integer.valueOf(total))
            .set(ClusterMetrics.K, Integer.valueOf(clusterSizes.size()))
            .set(ClusterMetrics.CLUSTER_ARRAY, clusterIds)
            .set(ClusterMetrics.COUNT_ARRAY, counts);
    }

    /**
     * Accumulates the vectors of one cluster into a sum and a mean.
     *
     * <p>Unless {@code fastDistance} is a {@link EuclideanDistance}, every input vector is rescaled
     * <em>in place</em> to unit length before it is added. When {@code fastDistance} is a
     * {@link CosineDistance} the mean is additionally rescaled to unit length.
     *
     * @param iterable     vectors (f0) with their cluster id (f1); the id of the first element is used
     * @param i            dimension of the accumulator vector
     * @param fastDistance distance whose type selects the normalisation behaviour
     * @return (cluster id, mean vector, sum vector); the id is {@code null} for an empty input
     */
    public static Tuple3<String, DenseVector, DenseVector> calMeanAndSum(Iterable<Tuple2<Vector, String>> iterable, int i, FastDistance fastDistance) {
        boolean euclidean = fastDistance instanceof EuclideanDistance;
        DenseVector sum = DenseVector.zeros(i);
        String clusterId = null;
        int count = 0;
        for (Tuple2<Vector, String> sample : iterable) {
            if (null == clusterId) {
                clusterId = sample.f1;
            }
            Vector vector = sample.f0;
            if (!euclidean) {
                // Mutates the caller's vector.
                vector.scaleEqual(1.0d / vector.normL2());
            }
            sum.plusEqual(vector);
            count++;
        }
        DenseVector mean = sum.scale(1.0d / count);
        if (fastDistance instanceof CosineDistance) {
            mean.scaleEqual(1.0d / mean.normL2());
        }
        return Tuple3.of(clusterId, mean, sum);
    }

    /**
     * Builds the summary of one cluster from its members and the precomputed mean and sum.
     *
     * @param iterable           member vectors (f0) of the cluster
     * @param continuousDistance distance used between {@code tuple3.f1} and each member
     * @param tuple3             (cluster id, mean vector, sum vector) of the cluster
     * @return a summary built from the id, the member count, the average distance to the mean,
     *         the sum of squared distances, the sum of squared L2 norms, the mean vector,
     *         the distance and the sum vector
     */
    public static ClusterMetricsSummary getClusterStatistics(Iterable<Tuple2<Vector, String>> iterable, ContinuousDistance continuousDistance, Tuple3<String, DenseVector, DenseVector> tuple3) {
        String clusterId = tuple3.f0;
        DenseVector mean = tuple3.f1;
        DenseVector sum = tuple3.f2;

        int count = 0;
        double distanceSum = 0.0d;
        double distanceSquareSum = 0.0d;
        double normSquareSum = 0.0d;
        for (Tuple2<Vector, String> sample : iterable) {
            double distance = continuousDistance.calc(mean, sample.f0);
            distanceSum += distance;
            distanceSquareSum += distance * distance;
            normSquareSum += sample.f0.normL2Square();
            count++;
        }
        return new ClusterMetricsSummary(clusterId, count, distanceSum / count, distanceSquareSum, normSquareSum, mean, continuousDistance, sum);
    }
}
