package com.alibaba.alink.operator.batch.evaluation;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.ClusterEvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.ClusterMetrics;
import com.alibaba.alink.operator.common.evaluation.ClusterMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.EvaluationMetricsCollector;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.LongMatrix;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.operator.common.statistics.basicstatistic.SparseVectorSummary;
import com.alibaba.alink.params.evaluation.EvalClusterParams;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.operators.ProjectOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.EVAL_METRICS)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "predictionCol"), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})})
@NameCn("聚类评估")
@NameEn("Eval Cluster")
/* loaded from: input_file:com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp.class */
public final class EvalClusterBatchOp extends BatchOperator<EvalClusterBatchOp> implements EvalClusterParams<EvalClusterBatchOp>, EvaluationMetricsCollector<ClusterMetrics, EvalClusterBatchOp> {
    public static final String SILHOUETTE_COEFFICIENT = "silhouetteCoefficient";
    private static final String METRICS_SUMMARY = "metricsSummary";
    private static final String EVAL_RESULT = "cluster_eval_result";
    private static final String LABELS = "labels";
    private static final String PREDICTIONS = "predictions";
    private static final String VECTOR_SIZE = "vectorSize";
    private static final String MEAN_AND_SUM = "meanAndSum";
    private static final long serialVersionUID = -1334962642325725386L;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp$BasicClusterParams.class */
    public static class BasicClusterParams implements GroupReduceFunction<Row, Params> {
        private static final long serialVersionUID = 6863171040793536672L;

        public void reduce(Iterable<Row> iterable, Collector<Params> collector) {
            collector.collect(ClusterEvaluationUtil.getBasicClusterStatistics(iterable));
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp$CalLocalPredResult.class */
    public static class CalLocalPredResult extends RichMapPartitionFunction<Row, LongMatrix> {
        private static final long serialVersionUID = -3838344564725765751L;
        private Map<Object, Integer> labels;
        private Map<Object, Integer> predictions;

        CalLocalPredResult() {
        }

        public void open(Configuration configuration) throws Exception {
            this.labels = (Map) ((Tuple1) getRuntimeContext().getBroadcastVariable(EvalClusterBatchOp.LABELS).get(0)).f0;
            this.predictions = (Map) ((Tuple1) getRuntimeContext().getBroadcastVariable(EvalClusterBatchOp.PREDICTIONS).get(0)).f0;
        }

        public void mapPartition(Iterable<Row> iterable, Collector<LongMatrix> collector) {
            long[][] jArr = new long[this.predictions.size()][this.labels.size()];
            for (Row row : iterable) {
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    int intValue = this.labels.get(row.getField(0)).intValue();
                    long[] jArr2 = jArr[this.predictions.get(row.getField(1)).intValue()];
                    jArr2[intValue] = jArr2[intValue] + 1;
                }
            }
            collector.collect(new LongMatrix(jArr));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp$CalcClusterMetricsSummary.class */
    public static class CalcClusterMetricsSummary extends RichCoGroupFunction<Tuple2<Vector, String>, Tuple3<String, DenseVector, DenseVector>, BaseMetricsSummary> {
        private static final long serialVersionUID = 346446456425064132L;
        private FastDistance distance;

        public CalcClusterMetricsSummary(FastDistance fastDistance) {
            this.distance = fastDistance;
        }

        public void coGroup(Iterable<Tuple2<Vector, String>> iterable, Iterable<Tuple3<String, DenseVector, DenseVector>> iterable2, Collector<BaseMetricsSummary> collector) {
            Tuple3<String, DenseVector, DenseVector> tuple3 = null;
            Iterator<Tuple3<String, DenseVector, DenseVector>> it = iterable2.iterator();
            while (it.hasNext()) {
                tuple3 = it.next();
            }
            collector.collect(ClusterEvaluationUtil.getClusterStatistics(iterable, this.distance, tuple3));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp$CalcMeanAndSum.class */
    public static class CalcMeanAndSum extends RichGroupReduceFunction<Tuple2<Vector, String>, Tuple3<String, DenseVector, DenseVector>> {
        private static final long serialVersionUID = 346446456425064132L;
        private int vectorSize;
        private FastDistance distance;

        public CalcMeanAndSum(FastDistance fastDistance) {
            this.distance = fastDistance;
        }

        public void open(Configuration configuration) {
            this.vectorSize = ((BaseVectorSummary) getRuntimeContext().getBroadcastVariable("vectorSize").get(0)).vectorSize();
        }

        public void reduce(Iterable<Tuple2<Vector, String>> iterable, Collector<Tuple3<String, DenseVector, DenseVector>> collector) {
            collector.collect(ClusterEvaluationUtil.calMeanAndSum(iterable, this.vectorSize, this.distance));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp$FilterEmptyRow.class */
    public static class FilterEmptyRow extends RichFlatMapFunction<Tuple2<Vector, Row>, Tuple2<Vector, String>> {
        private static final long serialVersionUID = 4239894668365119029L;
        private int vectorSize;
        private boolean isSparse;

        private FilterEmptyRow() {
        }

        public void open(Configuration configuration) {
            BaseVectorSummary baseVectorSummary = (BaseVectorSummary) getRuntimeContext().getBroadcastVariable("vectorSize").get(0);
            this.vectorSize = baseVectorSummary.vectorSize();
            this.isSparse = baseVectorSummary instanceof SparseVectorSummary;
        }

        public void flatMap(Tuple2<Vector, Row> tuple2, Collector<Tuple2<Vector, String>> collector) throws Exception {
            if (tuple2.f0 == null || ((Row) tuple2.f1).getField(0) == null) {
                return;
            }
            if (this.isSparse) {
                ((SparseVector) tuple2.f0).setSize(this.vectorSize);
            }
            collector.collect(Tuple2.of(tuple2.f0, ((Row) tuple2.f1).getField(0).toString()));
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Tuple2<Vector, Row>) obj, (Collector<Tuple2<Vector, String>>) collector);
        }
    }

    /**
     * Creates the operator without preset parameters; the parent constructor receives
     * {@code null}.
     */
    public EvalClusterBatchOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters passed unchanged to the {@link BatchOperator} constructor
     */
    public EvalClusterBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Wraps the collected evaluation output in a {@link ClusterMetrics} object.
     *
     * @param list collected rows; only the first row is read
     * @return metrics built from {@code list.get(0)}
     */
    @Override
    public ClusterMetrics createMetrics(List<Row> list) {
        Row metricsRow = list.get(0);
        return new ClusterMetrics(metricsRow);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Builds the cluster-evaluation job on the first input operator and registers its output table.
     *
     * <p>The result is a single-row, single-column table ({@code cluster_eval_result}) holding the
     * JSON of the merged {@link Params} of two partial results:
     * <ul>
     *   <li>label metrics – computed from the (prediction x label) count matrix when a label column
     *       is configured, otherwise an empty {@link Params};</li>
     *   <li>vector metrics – computed from per-group summaries and the summed output of
     *       {@code calSilhouetteCoefficient} when a vector column is configured, otherwise the
     *       basic statistics produced by {@link BasicClusterParams}.</li>
     * </ul>
     *
     * <p>No explicit bridge overloads ({@code linkFrom(BatchOperator[])},
     * {@code createMetrics(List)}, {@code flatMap(Object, Collector)}) are declared: the compiler
     * generates them, and hand-written copies clash with the erasure of the generic methods.
     *
     * @param batchOperatorArr input operators; the first one supplies the data
     * @return this operator
     */
    @Override
    public EvalClusterBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        String labelCol = getLabelCol();
        String predictionCol = getPredictionCol();
        String vectorCol = getVectorCol();
        FastDistance distance = getDistanceType().getFastDistance();

        // Empty placeholder so the final union is well-formed when no label column is configured.
        DataSet<Params> labelMetrics = MLEnvironmentFactory.get(getMLEnvironmentId())
            .getExecutionEnvironment()
            .fromElements(new Params());
        DataSet<Params> vectorMetrics;

        if (null != labelCol) {
            DataSet<Row> labelAndPrediction = in.select(new String[] {labelCol, predictionCol}).getDataSet();

            // Value-to-index map of the label column (field 0).
            DataSet<Tuple1<Map<Object, Integer>>> labels = labelAndPrediction
                .flatMap(new FlatMapFunction<Row, Object>() {
                    private static final long serialVersionUID = 6181506719667975996L;

                    @Override
                    public void flatMap(Row row, Collector<Object> collector) {
                        if (EvaluationUtil.checkRowFieldNotNull(row)) {
                            collector.collect(row.getField(0));
                        }
                    }
                })
                .reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false))
                .project(0);

            // Value-to-index map of the prediction column (field 1).
            DataSet<Tuple1<Map<Object, Integer>>> predictions = labelAndPrediction
                .flatMap(new FlatMapFunction<Row, Object>() {
                    private static final long serialVersionUID = 619373417169823128L;

                    @Override
                    public void flatMap(Row row, Collector<Object> collector) {
                        if (EvaluationUtil.checkRowFieldNotNull(row)) {
                            collector.collect(row.getField(1));
                        }
                    }
                })
                .reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false))
                .project(0);

            // Per-partition count matrices are summed, then converted to Params.
            labelMetrics = labelAndPrediction
                .rebalance()
                .mapPartition(new CalLocalPredResult())
                .withBroadcastSet(labels, LABELS)
                .withBroadcastSet(predictions, PREDICTIONS)
                .reduce(new ReduceFunction<LongMatrix>() {
                    private static final long serialVersionUID = 3340266128816528106L;

                    @Override
                    public LongMatrix reduce(LongMatrix longMatrix, LongMatrix longMatrix2) {
                        longMatrix.plusEqual(longMatrix2);
                        return longMatrix;
                    }
                })
                .map(new RichMapFunction<LongMatrix, Params>() {
                    private static final long serialVersionUID = -4218363116865487327L;

                    @Override
                    public Params map(LongMatrix longMatrix) {
                        List<Tuple1<Map<Object, Integer>>> labelVariable =
                            getRuntimeContext().getBroadcastVariable(LABELS);
                        List<Tuple1<Map<Object, Integer>>> predictionVariable =
                            getRuntimeContext().getBroadcastVariable(PREDICTIONS);
                        return ClusterEvaluationUtil.extractParamsFromConfusionMatrix(
                            longMatrix, labelVariable.get(0).f0, predictionVariable.get(0).f0);
                    }
                })
                .withBroadcastSet(labels, LABELS)
                .withBroadcastSet(predictions, PREDICTIONS);
        }

        if (null != vectorCol) {
            Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> statistics =
                StatisticsHelper.summaryHelper(batchOperatorArr[0], null, vectorCol, new String[] {predictionCol});

            // (vector, prediction-as-string) pairs with incomplete records removed.
            DataSet<Tuple2<Vector, String>> nonEmpty = statistics.f0
                .flatMap(new FilterEmptyRow())
                .withBroadcastSet(statistics.f1, VECTOR_SIZE);

            // One record per prediction value (grouped on tuple field 1).
            DataSet<Tuple3<String, DenseVector, DenseVector>> meanAndSum = nonEmpty
                .groupBy(1)
                .reduceGroup(new CalcMeanAndSum(distance))
                .withBroadcastSet(statistics.f1, VECTOR_SIZE);

            // NOTE(review): CalcClusterMetricsSummary never reads the MEAN_AND_SUM broadcast set in
            // its visible code; the registration is kept so the job plan stays as it was.
            DataSet<BaseMetricsSummary> metricsSummary = nonEmpty
                .coGroup(meanAndSum)
                .where(1)
                .equalTo(0)
                .with(new CalcClusterMetricsSummary(distance))
                .withBroadcastSet(meanAndSum, MEAN_AND_SUM)
                .reduce(new EvaluationUtil.ReduceBaseMetrics());

            // Per-sample values from calSilhouetteCoefficient, summed over all samples.
            DataSet<Tuple1<Double>> silhouetteCoefficient = nonEmpty
                .map(new RichMapFunction<Tuple2<Vector, String>, Tuple1<Double>>() {
                    private static final long serialVersionUID = 116926378586242272L;

                    @Override
                    public Tuple1<Double> map(Tuple2<Vector, String> tuple2) {
                        return ClusterEvaluationUtil.calSilhouetteCoefficient(tuple2,
                            (ClusterMetricsSummary) getRuntimeContext().getBroadcastVariable(METRICS_SUMMARY).get(0));
                    }
                })
                .withBroadcastSet(metricsSummary, METRICS_SUMMARY)
                .aggregate(Aggregations.SUM, 0);

            vectorMetrics = metricsSummary
                .map(new ClusterEvaluationUtil.SaveDataAsParams())
                .withBroadcastSet(silhouetteCoefficient, SILHOUETTE_COEFFICIENT);
        } else {
            vectorMetrics = in.select(predictionCol).getDataSet().reduceGroup(new BasicClusterParams());
        }

        // Merge both partial results into one Params object and serialize it as JSON.
        DataSet<Row> out = labelMetrics
            .union(vectorMetrics)
            .reduceGroup(new GroupReduceFunction<Params, Row>() {
                private static final long serialVersionUID = -4726713311986089251L;

                @Override
                public void reduce(Iterable<Params> iterable, Collector<Row> collector) {
                    Params merged = new Params();
                    for (Params partial : iterable) {
                        merged.merge(partial);
                    }
                    collector.collect(Row.of(merged.toJson()));
                }
            });

        setOutputTable(DataSetConversionUtil.toTable(getMLEnvironmentId(), out,
            new TableSchema(new String[] {EVAL_RESULT}, new TypeInformation[] {Types.STRING})));
        return this;
    }
}
