package com.alibaba.alink.operator.local.evaluation;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.evaluation.ClassificationEvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.ClusterEvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.ClusterMetrics;
import com.alibaba.alink.operator.common.evaluation.ClusterMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.LongMatrix;
import com.alibaba.alink.operator.common.statistics.basicstatistic.DenseVectorSummarizer;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.params.evaluation.EvalClusterParams;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/**
 * Local operator that evaluates a clustering result and emits the metrics as a
 * single-row, single-column table ({@code cluster_eval_result}) holding a JSON string.
 *
 * <p>What is computed depends on which columns are configured:
 * <ul>
 *   <li>If the label column is set, a (prediction x label) count matrix is built and
 *       handed to {@link ClusterEvaluationUtil#extractParamsFromConfusionMatrix}.</li>
 *   <li>If the vector column is set, per-cluster statistics are computed, merged, and
 *       the mean of the per-sample silhouette values is stored under
 *       {@link #SILHOUETTE_COEFFICIENT}.</li>
 *   <li>If the vector column is not set,
 *       {@link ClusterEvaluationUtil#getBasicClusterStatistics} is applied to the
 *       prediction column only.</li>
 * </ul>
 * The two resulting parameter sets are merged into one before serialization.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.EVAL_METRICS)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "predictionCol"), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})})
@NameCn("聚类评估")
public final class EvalClusterLocalOp extends LocalOperator<EvalClusterLocalOp> implements EvalClusterParams<EvalClusterLocalOp>, EvaluationMetricsCollector<ClusterMetrics, EvalClusterLocalOp> {
    public static final String SILHOUETTE_COEFFICIENT = "silhouetteCoefficient";
    private static final String METRICS_SUMMARY = "metricsSummary";
    private static final String EVAL_RESULT = "cluster_eval_result";
    private static final String LABELS = "labels";
    private static final String VECTOR_SIZE = "vectorSize";
    private static final String MEAN_AND_SUM = "meanAndSum";
    private static final long serialVersionUID = -1334962642325725386L;

    /** Creates the operator without explicit parameters (delegates with {@code null}). */
    public EvalClusterLocalOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters, passed through to the superclass
     */
    public EvalClusterLocalOp(Params params) {
        super(params);
    }

    /**
     * Wraps the first collected row into a {@link ClusterMetrics} object.
     *
     * @param list rows produced by this operator; only the first one is used
     * @return metrics built from {@code list.get(0)}
     */
    @Override
    public ClusterMetrics createMetrics(List<Row> list) {
        return new ClusterMetrics(list.get(0));
    }

    /**
     * Evaluates the clustering found in the first input operator and stores the
     * metrics JSON as this operator's output table.
     *
     * <p>An input without rows short-circuits to the metrics returned by
     * {@link ClusterMetricsSummary#createForEmptyDataset()}.
     *
     * @param localOperatorArr input operators; the first one supplies the data
     * @return this operator, with its output table set
     */
    @Override
    public EvalClusterLocalOp linkFrom(LocalOperator<?>... localOperatorArr) {
        LocalOperator<?> in = checkAndGetFirst(localOperatorArr);
        if (in.getOutputTable().getRows().isEmpty()) {
            setOutputTable(buildEvalResultTable(
                ClusterMetricsSummary.createForEmptyDataset().getParams().toJson()));
            return this;
        }

        String labelCol = getLabelCol();
        String predictionCol = getPredictionCol();
        String vectorCol = getVectorCol();
        // Resolved up front, regardless of whether the vector column is set.
        FastDistance distance = getDistanceType().getFastDistance();

        Params labelMetrics = labelCol == null
            ? new Params()
            : evaluateAgainstLabels(in, labelCol, predictionCol);

        Params clusterMetrics = vectorCol == null
            ? ClusterEvaluationUtil.getBasicClusterStatistics(
                in.select(predictionCol).getOutputTable().getRows())
            : evaluateWithVectors(in, vectorCol, predictionCol, distance);

        // Merge into a copy so the label metrics object itself is left untouched.
        Params merged = labelMetrics.clone();
        merged.merge(clusterMetrics);
        setOutputTable(buildEvalResultTable(merged.toJson()));
        return this;
    }

    /** Builds the one-row, one-column result table that carries the metrics JSON. */
    private static MTable buildEvalResultTable(String metricsJson) {
        return new MTable(
            new Row[] {Row.of(metricsJson)},
            new TableSchema(new String[] {EVAL_RESULT}, new TypeInformation<?>[] {Types.STRING}));
    }

    /**
     * Builds the (prediction x label) count matrix from the rows that pass
     * {@link EvaluationUtil#checkRowFieldNotNull} and converts it to metric parameters.
     */
    @SuppressWarnings("unchecked") // defensive casts on Tuple2.f0 at the utility API boundary
    private static Params evaluateAgainstLabels(LocalOperator<?> in, String labelCol, String predictionCol) {
        MTable data = in.select(new String[] {labelCol, predictionCol}).getOutputTable();

        // Declared as HashSet (not Set) because the sets are handed to
        // buildLabelIndexLabelArray exactly as the original code did.
        HashSet<Object> labelValues = new HashSet<>();
        HashSet<Object> predictionValues = new HashSet<>();
        for (Row row : data.getRows()) {
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                labelValues.add(row.getField(0));
                predictionValues.add(row.getField(1));
            }
        }

        Map<Object, Integer> labelIndex = (Map<Object, Integer>)
            ClassificationEvaluationUtil.buildLabelIndexLabelArray(labelValues, false, null, null, false).f0;
        Map<Object, Integer> predictionIndex = (Map<Object, Integer>)
            ClassificationEvaluationUtil.buildLabelIndexLabelArray(predictionValues, false, null, null, false).f0;

        // Rows are indexed by prediction, columns by label.
        long[][] counts = new long[predictionIndex.size()][labelIndex.size()];
        for (Row row : data.getRows()) {
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                int labelIdx = labelIndex.get(row.getField(0));
                int predictionIdx = predictionIndex.get(row.getField(1));
                counts[predictionIdx][labelIdx]++;
            }
        }
        return ClusterEvaluationUtil.extractParamsFromConfusionMatrix(
            new LongMatrix(counts), labelIndex, predictionIndex);
    }

    /**
     * Computes the vector-based cluster metrics and adds the mean silhouette value.
     *
     * <p>NOTE(review): a {@code null} prediction value makes {@code toString()} throw a
     * {@link NullPointerException} here; rows are not filtered in this branch. Confirm
     * that upstream guarantees non-null predictions.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Tuple3: its element types are not imported in this file
    private static Params evaluateWithVectors(LocalOperator<?> in, String vectorCol, String predictionCol,
                                              FastDistance distance) {
        List<Tuple2<Vector, String>> samples = new ArrayList<>();
        for (Row row : in.select(new String[] {vectorCol, predictionCol}).getOutputTable().getRows()) {
            samples.add(Tuple2.of(VectorUtil.getVector(row.getField(0)), row.getField(1).toString()));
        }

        // The summarizer is only used to obtain the vector size of the data set.
        DenseVectorSummarizer summarizer = new DenseVectorSummarizer(false);
        for (Tuple2<Vector, String> sample : samples) {
            summarizer.visit(sample.f0);
        }
        int vectorSize = summarizer.toSummary().vectorSize();

        // Group samples by predicted cluster id. A HashMap is kept so that the
        // iteration (and therefore merge) order matches the previous implementation.
        Map<String, List<Tuple2<Vector, String>>> samplesByCluster = new HashMap<>();
        for (Tuple2<Vector, String> sample : samples) {
            samplesByCluster.computeIfAbsent(sample.f1, key -> new ArrayList<>()).add(sample);
        }

        // The caller guarantees at least one row, so at least one cluster exists
        // and 'summary' is non-null after this loop.
        ClusterMetricsSummary summary = null;
        for (List<Tuple2<Vector, String>> members : samplesByCluster.values()) {
            Tuple3 meanAndSum = ClusterEvaluationUtil.calMeanAndSum(members, vectorSize, distance);
            ClusterMetricsSummary clusterSummary =
                ClusterEvaluationUtil.getClusterStatistics(members, distance, meanAndSum);
            if (summary == null) {
                summary = clusterSummary;
            } else {
                summary.merge(clusterSummary);
            }
        }

        double silhouetteSum = 0.0d;
        for (Tuple2<Vector, String> sample : samples) {
            silhouetteSum += (Double) ClusterEvaluationUtil.calSilhouetteCoefficient(sample, summary).f0;
        }

        ClusterMetrics metrics = summary.toMetrics();
        Params result = metrics.getParams();
        result.set(SILHOUETTE_COEFFICIENT, Double.valueOf(silhouetteSum / metrics.getCount().intValue()));
        return result;
    }
}
