package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.graph.CommunityDetectionClassifyBatchOp;
import com.alibaba.alink.params.graph.CommunityDetectionClusterParams;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.graph.Edge;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Label-propagation clustering over a weighted graph.
 *
 * <p>Input port 0 holds the edges (source, target and an optional weight column). Optional input port 1 holds
 * the vertices (vertex id and an optional weight column). Every vertex starts with its own internal id as its
 * label. In each iteration, every edge sends the source vertex's label to the target vertex. The message weight
 * is edge weight times source vertex weight. Messages are grouped per target and merged with the current labels.
 *
 * <p>The output has two columns: {@code vertex} (same type as the edge id columns) and {@code label}
 * ({@code LONG}).
 */
@InputPorts(values = {
    @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRPAH_EDGES),
    @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRAPH_VERTICES, isOptional = true)
})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "vertexCol", portIndices = {1}),
    @ParamSelectColumnSpec(name = "vertexWeightCol", portIndices = {1}),
    // Edge columns are selected from input port 0 (the edge table).
    @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = {0}),
    @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = {0}),
    @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = {0})
})
@NameCn("标签传播聚类")
// NOTE(review): "Common" looks like a typo for "Community"; left as-is because the display name may be
// referenced elsewhere.
@NameEn("Common Detection Cluster")
public class CommunityDetectionClusterBatchOp extends BatchOperator<CommunityDetectionClusterBatchOp>
    implements CommunityDetectionClusterParams<CommunityDetectionClusterBatchOp> {
    private static final long serialVersionUID = 2041752348271477963L;

    public CommunityDetectionClusterBatchOp(Params params) {
        super(params);
    }

    public CommunityDetectionClusterBatchOp() {
        super(new Params());
    }

    /**
     * Builds the label-propagation plan.
     *
     * @param inputs {@code inputs[0]} is the edge operator; {@code inputs[1]}, when exactly two inputs are
     *               given, is the vertex operator.
     * @return this operator, with its output set to (vertex, label) rows.
     * @throws RuntimeException if the edge id columns differ in type, are not string/long/int, or do not
     *                          match the vertex column type.
     */
    @Override
    public CommunityDetectionClusterBatchOp linkFrom(BatchOperator<?>... inputs) {
        checkMinOpSize(1, inputs);
        final String edgeSourceCol = getEdgeSourceCol();
        final String edgeTargetCol = getEdgeTargetCol();
        final String edgeWeightCol = getEdgeWeightCol();
        final int maxIter = getMaxIter();
        final boolean asUndirectedGraph = getAsUndirectedGraph();
        final boolean hasEdgeWeight = isColumnSpecified(edgeWeightCol);

        BatchOperator<?> edgeOp = inputs[0];
        String[] edgeCols = hasEdgeWeight
            ? new String[] {edgeSourceCol, edgeTargetCol, edgeWeightCol}
            : new String[] {edgeSourceCol, edgeTargetCol};
        TypeInformation<?>[] edgeColTypes = TableUtil.findColTypes(edgeOp.getSchema(), edgeCols);
        validateEdgeNodeTypes(edgeColTypes);
        TypeInformation<?> nodeType = edgeColTypes[0];
        DataSet<Row> edgeRows = edgeOp.select(edgeCols).getDataSet();

        DataSet<Tuple2<String, Long>> nodeIdMapping;
        // Stays null when there is no vertex input; every vertex then keeps weight 1.0.
        DataSet<Tuple2<Long, Float>> vertexWeights = null;
        if (inputs.length == 2) {
            BatchOperator<?> vertexOp = inputs[1];
            String vertexCol = getVertexCol();
            String vertexWeightCol = getVertexWeightCol();
            boolean hasVertexWeight = isColumnSpecified(vertexWeightCol);
            String[] vertexCols = hasVertexWeight
                ? new String[] {vertexCol, vertexWeightCol}
                : new String[] {vertexCol};
            TypeInformation<?> vertexType = TableUtil.findColType(vertexOp.getSchema(), vertexCol);
            if (!vertexType.equals(nodeType)) {
                throw new RuntimeException(String.format(
                    "Edge sourceCol and Vertex column should be same type, sourceCol type %s, Vertex column type %s",
                    nodeType, vertexType));
            }
            DataSet<Row> vertexRows = vertexOp.select(vertexCols).getDataSet();
            nodeIdMapping = GraphUtils.graphNodeIdMapping(edgeRows, new int[] {0, 1}, vertexRows, 0);
            vertexWeights = toVertexWeights(
                GraphUtils.mapOriginalToId(vertexRows, nodeIdMapping, new int[] {0}), hasVertexWeight);
        } else {
            nodeIdMapping = GraphUtils.graphNodeIdMapping(edgeRows, new int[] {0, 1}, null, 0);
        }

        DataSet<Edge<Long, Float>> edges = toEdges(
            GraphUtils.mapOriginalToId(edgeRows, nodeIdMapping, new int[] {0, 1}), hasEdgeWeight, asUndirectedGraph);
        DataSet<Tuple3<Long, Long, Float>> labels =
            propagateLabels(edges, initialLabels(nodeIdMapping, vertexWeights), maxIter);
        DataSet<Row> result =
            GraphUtils.mapIdToOriginal(toIdLabelRows(labels), nodeIdMapping, new int[] {0}, nodeType);
        setOutput(result, new String[] {"vertex", "label"}, new TypeInformation<?>[] {nodeType, Types.LONG});
        return this;
    }

    /** Returns true when a column parameter names a column, i.e. it is neither null nor the literal "null". */
    private static boolean isColumnSpecified(String colName) {
        // equals() instead of ==: parameter strings are generally not interned, so reference comparison misses.
        return colName != null && !"null".equals(colName);
    }

    /** Checks that the edge source/target columns share one type and that it is string, long or int. */
    private static void validateEdgeNodeTypes(TypeInformation<?>[] edgeColTypes) {
        TypeInformation<?> sourceType = edgeColTypes[0];
        TypeInformation<?> targetType = edgeColTypes[1];
        if (!sourceType.equals(targetType)) {
            throw new RuntimeException(String.format(
                "Edge input data, sourceCol and targetCol should be same type, sourceCol type %s, targetCol type %s",
                sourceType, targetType));
        }
        if (!sourceType.equals(Types.STRING) && !sourceType.equals(Types.LONG) && !sourceType.equals(Types.INT)) {
            throw new RuntimeException(String.format(
                "Edge input data, sourceCol and targetCol should be string, long or integer. Input type is %s",
                sourceType));
        }
    }

    /** Maps id-encoded vertex rows to (vertexId, weight), using weight 1.0 when no weight column is given. */
    private static DataSet<Tuple2<Long, Float>> toVertexWeights(DataSet<Row> vertexIdRows,
                                                                final boolean hasVertexWeight) {
        return vertexIdRows.map(new MapFunction<Row, Tuple2<Long, Float>>() {
            @Override
            public Tuple2<Long, Float> map(Row row) {
                float weight = hasVertexWeight ? ((Number) row.getField(1)).floatValue() : 1.0f;
                return Tuple2.of((Long) row.getField(0), weight);
            }
        }).name("vertexRows_map_vertex_weight");
    }

    /**
     * Maps id-encoded edge rows to weighted edges.
     *
     * <p>Emits a reversed copy of each edge when the graph is treated as undirected.
     */
    private static DataSet<Edge<Long, Float>> toEdges(DataSet<Row> edgeIdRows, final boolean hasEdgeWeight,
                                                      final boolean asUndirectedGraph) {
        return edgeIdRows.flatMap(new FlatMapFunction<Row, Edge<Long, Float>>() {
            @Override
            public void flatMap(Row row, Collector<Edge<Long, Float>> out) {
                Long source = (Long) row.getField(0);
                Long target = (Long) row.getField(1);
                // Parsed through the string form so weight columns of any numeric (or numeric-string) type work.
                Float weight = hasEdgeWeight ? Float.parseFloat(String.valueOf(row.getField(2))) : 1.0f;
                out.collect(new Edge<>(source, target, weight));
                if (asUndirectedGraph) {
                    out.collect(new Edge<>(target, source, weight));
                }
            }
        }).name("map_row_to_edge");
    }

    /**
     * Builds the initial (vertexId, label, weight) set.
     *
     * <p>Each vertex is labelled with its own id. Its weight is taken from {@code vertexWeights} when that
     * dataset is non-null and contains the vertex; otherwise the weight is 1.0.
     */
    private static DataSet<Tuple3<Long, Long, Float>> initialLabels(DataSet<Tuple2<String, Long>> nodeIdMapping,
                                                                    DataSet<Tuple2<Long, Float>> vertexWeights) {
        DataSet<Tuple3<Long, Long, Float>> labels = nodeIdMapping
            .map(new MapFunction<Tuple2<String, Long>, Tuple3<Long, Long, Float>>() {
                @Override
                public Tuple3<Long, Long, Float> map(Tuple2<String, Long> mapping) {
                    return Tuple3.of(mapping.f1, mapping.f1, 1.0f);
                }
            }).name("init_vertex_label_map");
        if (vertexWeights == null) {
            return labels;
        }
        return labels
            .leftOuterJoin(vertexWeights).where(0).equalTo(0)
            .with(new JoinFunction<Tuple3<Long, Long, Float>, Tuple2<Long, Float>, Tuple3<Long, Long, Float>>() {
                @Override
                public Tuple3<Long, Long, Float> join(Tuple3<Long, Long, Float> label, Tuple2<Long, Float> weight) {
                    // Left outer join: vertices missing from the vertex input keep the default weight.
                    return weight == null ? label : Tuple3.of(label.f0, label.f1, weight.f1);
                }
            }).name("join_origin_vertex_input_weight");
    }

    /**
     * Runs the label-propagation bulk iteration.
     *
     * <p>The iteration runs for at most {@code maxIter} rounds. It stops early when no merged record has its
     * fourth field set. That flag is presumably "label changed"; verify against {@code ClusterLabelMerger}.
     *
     * @return the final (vertexId, label, weight) set.
     */
    private static DataSet<Tuple3<Long, Long, Float>> propagateLabels(
        DataSet<Edge<Long, Float>> edges, DataSet<Tuple3<Long, Long, Float>> initialLabels, int maxIter) {
        // Operator name kept for plan compatibility, although this is a bulk (not delta) iteration.
        IterativeDataSet<Tuple3<Long, Long, Float>> labels = initialLabels.iterate(maxIter).name("delta_iteration");

        DataSet<Tuple4<Long, Long, Float, Boolean>> merged = edges
            .join(labels).where(0).equalTo(0)
            .with(new JoinFunction<Edge<Long, Float>, Tuple3<Long, Long, Float>, Tuple3<Long, Long, Float>>() {
                @Override
                public Tuple3<Long, Long, Float> join(Edge<Long, Float> edge, Tuple3<Long, Long, Float> source) {
                    // Message to the target vertex: the source's label, weighted by edge * source weight.
                    return Tuple3.of(edge.getTarget(), source.f1, edge.getValue() * source.f2);
                }
            }).name("join_send_messages")
            .groupBy(0)
            .reduceGroup(new CommunityDetectionClassifyBatchOp.CommunityDetection.ClusterMessageGroupFunction())
            .name("message_group_function")
            .rightOuterJoin(labels).where(0).equalTo(0)
            .with(new CommunityDetectionClassifyBatchOp.CommunityDetection.ClusterLabelMerger())
            .name("join_label_merge");

        // Termination criterion: the iteration ends once this dataset is empty.
        DataSet<?> flaggedVertices = merged
            .filter(new FilterFunction<Tuple4<Long, Long, Float, Boolean>>() {
                @Override
                public boolean filter(Tuple4<Long, Long, Float, Boolean> record) {
                    return record.f3;
                }
            })
            .project(0);

        return labels.closeWith(merged.<Tuple3<Long, Long, Float>>project(0, 1, 2), flaggedVertices);
    }

    /** Drops the weight field, producing (vertexId, label) rows. */
    private static DataSet<Row> toIdLabelRows(DataSet<Tuple3<Long, Long, Float>> labels) {
        return labels.map(new MapFunction<Tuple3<Long, Long, Float>, Row>() {
            @Override
            public Row map(Tuple3<Long, Long, Float> label) {
                Row row = new Row(2);
                row.setField(0, label.f0);
                row.setField(1, label.f1);
                return row;
            }
        });
    }
}
