package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.graph.ModularityCalParams;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.CrossFunction;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.ProjectOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.NeighborsFunctionWithVertexValue;
import org.apache.flink.graph.Vertex;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {
    @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRPAH_EDGES),
    @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRAPH_VERTICES)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "vertexCol", portIndices = {1}),
    @ParamSelectColumnSpec(name = "vertexCommunityCol", portIndices = {1},
        allowedTypeCollections = {TypeCollections.INT_LONG_TYPES}),
    // Edge columns live on input port 0 (the edge table).
    @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = {0}),
    @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = {0}),
    @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = {0})})
@NameCn("模块度计算")
@NameEn("Calculate Modularity")
/**
 * Computes the modularity of a given community assignment on a graph.
 *
 * <p>Input port 0 holds the edges (source, target and an optional weight column); input port 1 holds
 * the vertices together with the community each vertex belongs to. The output is a single row with one
 * {@code DOUBLE} column named {@code modularity}.
 *
 * <p>The value is {@code sum_c(in_c) / m - sum_c(tot_c^2) / m^2}, where for each community {@code c},
 * {@code in_c} is the total weight of edges whose source and target are both in {@code c},
 * {@code tot_c} is the total weight of edges whose source is in {@code c}, and {@code m} is the total
 * weight of all edges. Edges whose source or target does not appear in the vertex table are dropped by
 * the inner joins and do not contribute.
 */
public class ModularityCalBatchOp extends BatchOperator<ModularityCalBatchOp>
    implements ModularityCalParams<ModularityCalBatchOp> {
    private static final long serialVersionUID = -7765756516724178687L;

    /** Reusable Flink DataSet / Gelly building blocks for computing modularity. */
    public static class ModularityCal {

        /**
         * Combines the intra-community edge count with the sum of squared community degrees into the
         * modularity value. Reads the total edge count {@code m} from the broadcast variable {@code "m"}.
         */
        protected static class CrossStep extends AbstractRichFunction
            implements CrossFunction<Tuple1<Long>, Tuple1<Long>, Tuple1<Double>> {
            private static final long serialVersionUID = -7359362890112928974L;
            private Tuple1<Long> mTuple;

            protected CrossStep() {
            }

            @Override
            public void open(Configuration configuration) throws Exception {
                // Keeps the last broadcast element; the broadcast set is a single global aggregate.
                for (Tuple1<Long> element : getRuntimeContext().<Tuple1<Long>>getBroadcastVariable("m")) {
                    this.mTuple = element;
                }
            }

            @Override
            public Tuple1<Double> cross(Tuple1<Long> intraCommunity, Tuple1<Long> squaredDegrees) throws Exception {
                // Evaluate in double so m * m cannot overflow a long.
                double m = mTuple.f0;
                return Tuple1.of(intraCommunity.f0 / m - squaredDegrees.f0 / (m * m));
            }
        }

        /**
         * Emits {@code (sourceCommunity, targetCommunity, 1)} for every neighbour reached from a vertex.
         * Edge values are ignored; every edge counts as 1.
         */
        public static class ErgodicEdge
            implements NeighborsFunctionWithVertexValue<Long, Long, Double, Tuple3<Long, Long, Long>> {
            private static final long serialVersionUID = 5295386257754049577L;

            @Override
            public void iterateNeighbors(Vertex<Long, Long> vertex,
                                         Iterable<Tuple2<Edge<Long, Double>, Vertex<Long, Long>>> neighbors,
                                         Collector<Tuple3<Long, Long, Long>> collector) {
                Long sourceCommunity = vertex.f1;
                for (Tuple2<Edge<Long, Double>, Vertex<Long, Long>> neighbor : neighbors) {
                    // Emit the neighbour's community (f1), not its vertex id (f0): downstream code compares
                    // this field with the source community and groups on it as the target community.
                    collector.collect(Tuple3.of(sourceCommunity, neighbor.f1.f1, 1L));
                }
            }
        }

        /** Keeps only tuples whose first two fields (source / target community) are equal. */
        public static class FilterDiag implements FilterFunction<Tuple3<Long, Long, Long>> {
            private static final long serialVersionUID = 6595663411872011784L;

            @Override
            public boolean filter(Tuple3<Long, Long, Long> value) throws Exception {
                return value.f0.equals(value.f1);
            }
        }

        /**
         * Per partition, accumulates for each source community (field 1) a pair
         * {@code (intra-community weight, inter-community weight)} from
         * {@code (srcId, srcCommunity, dstId, dstCommunity, weight)} tuples.
         */
        private static class MapModularity extends
            RichMapPartitionFunction<Tuple5<Long, Long, Long, Long, Double>, HashMap<Long, Tuple2<Double, Double>>> {
            private MapModularity() {
            }

            @Override
            public void mapPartition(Iterable<Tuple5<Long, Long, Long, Long, Double>> values,
                                     Collector<HashMap<Long, Tuple2<Double, Double>>> collector) throws Exception {
                HashMap<Long, Tuple2<Double, Double>> sums = new HashMap<>();
                for (Tuple5<Long, Long, Long, Long, Double> edge : values) {
                    Tuple2<Double, Double> acc = sums.computeIfAbsent(edge.f1, k -> Tuple2.of(0.0, 0.0));
                    if (edge.f1.equals(edge.f3)) {
                        acc.f0 += edge.f4;
                    } else {
                        acc.f1 += edge.f4;
                    }
                }
                collector.collect(sums);
            }
        }

        /** Squares the single long field. */
        public static class MapSquare implements MapFunction<Tuple1<Long>, Tuple1<Long>> {
            private static final long serialVersionUID = -1719101888137570397L;

            @Override
            public Tuple1<Long> map(Tuple1<Long> value) throws Exception {
                return new Tuple1<>(value.f0 * value.f0);
            }
        }

        /** Sums the count field (field 2), keeping the key fields of the first tuple. */
        public static class ReduceOnCommunity implements ReduceFunction<Tuple3<Long, Long, Long>> {
            private static final long serialVersionUID = 3502336992662864358L;

            @Override
            public Tuple3<Long, Long, Long> reduce(Tuple3<Long, Long, Long> first, Tuple3<Long, Long, Long> second)
                throws Exception {
                return new Tuple3<>(first.f0, first.f1, first.f2 + second.f2);
            }
        }

        /** Keys a tuple by its first two fields. */
        public static class SelectTuple implements KeySelector<Tuple3<Long, Long, Long>, Tuple2<Long, Long>> {
            private static final long serialVersionUID = 5638365638596494304L;

            @Override
            public Tuple2<Long, Long> getKey(Tuple3<Long, Long, Long> value) throws Exception {
                return Tuple2.of(value.f0, value.f1);
            }
        }

        /**
         * Computes modularity from {@code (community, intraWeight, totalWeight)} tuples.
         *
         * @param dataSet one tuple per edge: source community, weight if the edge stays inside the
         *                community (else 0), and the edge weight
         * @return a single-element data set holding the modularity value
         */
        public static DataSet<Tuple1<Double>> modularity(DataSet<Tuple3<Long, Double, Double>> dataSet) {
            DataSet<Tuple1<Double>> totalWeight = dataSet.aggregate(Aggregations.SUM, 2).project(2);
            return dataSet
                .groupBy(0)
                .aggregate(Aggregations.SUM, 1).and(Aggregations.SUM, 2)
                .reduceGroup(new RichGroupReduceFunction<Tuple3<Long, Double, Double>, Tuple1<Double>>() {
                    @Override
                    public void reduce(Iterable<Tuple3<Long, Double, Double>> communities,
                                       Collector<Tuple1<Double>> collector) throws Exception {
                        double m = getRuntimeContext().<Tuple1<Double>>getBroadcastVariable("m").get(0).f0;
                        double intraWeight = 0.0;
                        double squaredDegreeSum = 0.0;
                        for (Tuple3<Long, Double, Double> community : communities) {
                            intraWeight += community.f1;
                            squaredDegreeSum += community.f2 * community.f2;
                        }
                        collector.collect(Tuple1.of(intraWeight / m - squaredDegreeSum / (m * m)));
                    }
                })
                .withBroadcastSet(totalWeight, "m");
        }

        /**
         * Computes weighted modularity of a graph whose vertex values are community ids, using
         * outgoing edges and their values as weights.
         */
        public static DataSet<Tuple1<Double>> run2(Graph<Long, Long, Double> graph) {
            DataSet<Tuple5<Long, Long, Long, Long, Double>> neighborTuples = graph.groupReduceOnNeighbors(
                new NeighborsFunctionWithVertexValue<Long, Long, Double, Tuple5<Long, Long, Long, Long, Double>>() {
                    @Override
                    public void iterateNeighbors(Vertex<Long, Long> vertex,
                                                 Iterable<Tuple2<Edge<Long, Double>, Vertex<Long, Long>>> neighbors,
                                                 Collector<Tuple5<Long, Long, Long, Long, Double>> collector)
                        throws Exception {
                        for (Tuple2<Edge<Long, Double>, Vertex<Long, Long>> neighbor : neighbors) {
                            collector.collect(
                                Tuple5.of(vertex.f0, vertex.f1, neighbor.f1.f0, neighbor.f1.f1, neighbor.f0.f2));
                        }
                    }
                }, EdgeDirection.OUT);

            DataSet<Tuple1<Double>> totalWeight = neighborTuples.aggregate(Aggregations.SUM, 4).project(4);

            return neighborTuples
                .mapPartition(new MapModularity())
                .reduce(new ReduceFunction<HashMap<Long, Tuple2<Double, Double>>>() {
                    @Override
                    public HashMap<Long, Tuple2<Double, Double>> reduce(HashMap<Long, Tuple2<Double, Double>> left,
                                                                        HashMap<Long, Tuple2<Double, Double>> right)
                        throws Exception {
                        for (Map.Entry<Long, Tuple2<Double, Double>> entry : right.entrySet()) {
                            Tuple2<Double, Double> acc = left.computeIfAbsent(entry.getKey(), k -> Tuple2.of(0.0, 0.0));
                            acc.f0 += entry.getValue().f0;
                            acc.f1 += entry.getValue().f1;
                        }
                        return left;
                    }
                })
                .map(new RichMapFunction<HashMap<Long, Tuple2<Double, Double>>, Tuple1<Double>>() {
                    @Override
                    public Tuple1<Double> map(HashMap<Long, Tuple2<Double, Double>> communitySums) throws Exception {
                        double m = getRuntimeContext().<Tuple1<Double>>getBroadcastVariable("m").get(0).f0;
                        double intraWeight = 0.0;
                        double squaredDegreeSum = 0.0;
                        for (Tuple2<Double, Double> sums : communitySums.values()) {
                            intraWeight += sums.f0;
                            double degree = sums.f0 + sums.f1;
                            squaredDegreeSum += degree * degree;
                        }
                        return Tuple1.of(intraWeight / m - squaredDegreeSum / (m * m));
                    }
                })
                .withBroadcastSet(totalWeight, "m");
        }

        /**
         * Computes unweighted modularity (every outgoing edge counts as 1) of a graph whose vertex values
         * are community ids.
         */
        public static DataSet<Tuple1<Double>> run(Graph<Long, Long, Double> graph) {
            // (sourceCommunity, targetCommunity, edgeCount)
            DataSet<Tuple3<Long, Long, Long>> communityPairCounts = graph
                .groupReduceOnNeighbors(new ErgodicEdge(), EdgeDirection.OUT)
                .groupBy(new SelectTuple())
                .reduce(new ReduceOnCommunity());
            // Edge count per target community.
            DataSet<Tuple1<Long>> communityDegrees =
                communityPairCounts.groupBy(1).aggregate(Aggregations.SUM, 2).project(2);
            DataSet<Tuple1<Long>> intraCommunityCount =
                communityPairCounts.filter(new FilterDiag()).aggregate(Aggregations.SUM, 2).project(2);
            DataSet<Tuple1<Long>> squaredDegreeSum =
                communityDegrees.map(new MapSquare()).aggregate(Aggregations.SUM, 0);
            DataSet<Tuple1<Long>> totalEdges = communityDegrees.aggregate(Aggregations.SUM, 0);
            return intraCommunityCount
                .cross(squaredDegreeSum)
                .with(new CrossStep())
                .withBroadcastSet(totalEdges, "m");
        }
    }

    /** Converts a {@code (vertexId, communityId)} row into a Gelly vertex keyed by the string form of the id. */
    private static class RowToVertex implements MapFunction<Row, Vertex<String, Long>> {
        @Override
        public Vertex<String, Long> map(Row row) throws Exception {
            return new Vertex<>(String.valueOf(row.getField(0)), (Long) row.getField(1));
        }
    }

    /** Replaces an edge's source with the source vertex's community: {@code (srcCommunity, target, weight)}. */
    private static class TagSourceCommunity
        implements JoinFunction<Edge<String, Double>, Vertex<String, Long>, Tuple3<Long, String, Double>> {
        @Override
        public Tuple3<Long, String, Double> join(Edge<String, Double> edge, Vertex<String, Long> source)
            throws Exception {
            return Tuple3.of(source.f1, edge.f1, edge.f2);
        }
    }

    /**
     * Produces {@code (srcCommunity, intraWeight, weight)}, where {@code intraWeight} is the edge weight if
     * the target is in the same community and 0 otherwise.
     */
    private static class SplitByTargetCommunity
        implements JoinFunction<Tuple3<Long, String, Double>, Vertex<String, Long>, Tuple3<Long, Double, Double>> {
        @Override
        public Tuple3<Long, Double, Double> join(Tuple3<Long, String, Double> tagged, Vertex<String, Long> target)
            throws Exception {
            return tagged.f0.equals(target.f1)
                ? Tuple3.of(tagged.f0, tagged.f2, tagged.f2)
                : Tuple3.of(tagged.f0, 0.0, tagged.f2);
        }
    }

    /** Wraps the modularity value into a single-field row. */
    private static class ModularityToRow implements MapFunction<Tuple1<Double>, Row> {
        private static final long serialVersionUID = -6930741669256207575L;

        @Override
        public Row map(Tuple1<Double> value) throws Exception {
            Row row = new Row(1);
            row.setField(0, value.f0);
            return row;
        }
    }

    public ModularityCalBatchOp(Params params) {
        super(params);
    }

    public ModularityCalBatchOp() {
        super(new Params());
    }

    /**
     * Links the edge table (input 0) and the vertex/community table (input 1) and sets a single-row
     * output with the {@code modularity} column.
     */
    @Override
    public ModularityCalBatchOp linkFrom(BatchOperator<?>... inputs) {
        checkOpSize(2, inputs);
        String edgeSourceCol = getEdgeSourceCol();
        String edgeTargetCol = getEdgeTargetCol();
        String edgeWeightCol = getEdgeWeightCol();
        boolean hasWeightCol = edgeWeightCol != null;
        String[] edgeCols = hasWeightCol
            ? new String[] {edgeSourceCol, edgeTargetCol, edgeWeightCol}
            : new String[] {edgeSourceCol, edgeTargetCol};
        String[] vertexCols = new String[] {getVertexCol(), getVertexCommunityCol()};
        boolean asUndirectedGraph = getAsUndirectedGraph();

        DataSet<Row> edgeRows = inputs[0].select(edgeCols).getDataSet();
        DataSet<Row> vertexRows = inputs[1].select(vertexCols).getDataSet();

        // NOTE(review): presumably maps each distinct community value (column 1) to a Long id — see GraphUtils.
        DataSet<Tuple2<String, Long>> communityIdMapping =
            GraphUtils.graphNodeIdMapping(vertexRows, new int[] {1}, null, 0);
        DataSet<Edge<String, Double>> edges = GraphUtils.rowToEdges(edgeRows, hasWeightCol, asUndirectedGraph);
        DataSet<Vertex<String, Long>> vertices = GraphUtils
            .mapOriginalToId(vertexRows, communityIdMapping, new int[] {1})
            .map(new RowToVertex());

        DataSet<Tuple3<Long, String, Double>> sourceTagged = edges
            .join(vertices).where(0).equalTo(0)
            .with(new TagSourceCommunity());
        DataSet<Tuple3<Long, Double, Double>> communityWeights = sourceTagged
            .join(vertices).where(1).equalTo(0)
            .with(new SplitByTargetCommunity());

        DataSet<Row> result = ModularityCal.modularity(communityWeights).map(new ModularityToRow());
        setOutput(result, new String[] {"modularity"}, new TypeInformation<?>[] {Types.DOUBLE});
        return this;
    }
}
