package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.utils.JsonConverter;
import java.util.Iterator;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/batch/graph/GraphUtils.class */
/**
 * Static helpers for turning {@link Row} data sets into graph edges and for translating node
 * values to and from generated {@code long} ids.
 *
 * <p>Node values are keyed by their JSON text (as produced by {@link JsonConverter#toJson}), so
 * two values map to the same node exactly when their JSON representations are equal.
 */
public class GraphUtils {

    /**
     * Converts each input row into one or two weighted edges.
     *
     * <p>Field 0 of the row is used as the first endpoint and field 1 as the second endpoint; both
     * are converted with {@link String#valueOf(Object)}.
     *
     * @param dataSet rows describing edges
     * @param z       when {@code true}, the edge value is parsed from field 2 of the row; otherwise
     *                every edge gets the value {@code 1.0}
     * @param z2      when {@code true}, a second edge with the endpoints swapped (and the same value)
     *                is emitted for every row
     * @return the data set of generated edges
     */
    public static DataSet<Edge<String, Double>> rowToEdges(DataSet<Row> dataSet, final boolean z, final boolean z2) {
        return dataSet.flatMap(new FlatMapFunction<Row, Edge<String, Double>>() {
            @Override
            public void flatMap(Row row, Collector<Edge<String, Double>> collector) throws Exception {
                // The value is resolved before the endpoints, so a malformed value fails first.
                double weight = 1.0d;
                if (z) {
                    weight = Double.parseDouble(String.valueOf(row.getField(2)));
                }

                Edge<String, Double> forward = new Edge<>(
                    String.valueOf(row.getField(0)), String.valueOf(row.getField(1)), weight);
                collector.collect(forward);

                if (z2) {
                    Edge<String, Double> backward = new Edge<>(
                        String.valueOf(row.getField(1)), String.valueOf(row.getField(0)), weight);
                    collector.collect(backward);
                }
            }
        });
    }

    /**
     * Builds a mapping from the JSON text of every distinct node value to a generated {@code long}
     * id.
     *
     * <p>Each parallel subtask numbers the values it sees as
     * {@code parallelism * localCounter + subtaskIndex}. Ids produced by different subtasks
     * therefore differ modulo the parallelism and cannot collide, but the overall id range is not
     * guaranteed to be contiguous.
     *
     * @param dataSet  rows whose fields at the positions in {@code iArr} hold node values
     * @param iArr     field indices of {@code dataSet} to read node values from
     * @param dataSet2 optional second data set contributing node values; may be {@code null}
     * @param i        field index of {@code dataSet2} to read the node value from
     * @return pairs of (JSON text of the node value, generated id)
     */
    public static DataSet<Tuple2<String, Long>> graphNodeIdMapping(DataSet<Row> dataSet, final int[] iArr, DataSet<Row> dataSet2, final int i) {
        DataSet<String> nodes = dataSet.flatMap(new FlatMapFunction<Row, String>() {
            @Override
            public void flatMap(Row row, Collector<String> collector) throws Exception {
                for (int index : iArr) {
                    collector.collect(JsonConverter.toJson(row.getField(index)));
                }
            }
        }).name("graphNodeIdMapping_flatmap_edge_nodes");

        if (null != dataSet2) {
            DataSet<String> extraNodes = dataSet2.map(new MapFunction<Row, String>() {
                @Override
                public String map(Row row) throws Exception {
                    return JsonConverter.toJson(row.getField(i));
                }
            });
            nodes = nodes.union(extraNodes).name("graphNodeIdMapping_map_vertex_nodes");
        }

        return nodes.distinct().name("graphNodeIdMapping_distinct_nodes").map(new RichMapFunction<String, Tuple2<String, Long>>() {
            /** Number of values this subtask has numbered so far. */
            private long cnt;
            /** Parallelism of this operator. */
            private int numTasks;
            /** Index of this subtask. */
            private int taskId;

            @Override
            public void open(Configuration configuration) throws Exception {
                this.cnt = 0L;
                this.numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
                this.taskId = getRuntimeContext().getIndexOfThisSubtask();
            }

            @Override
            public Tuple2<String, Long> map(String str) throws Exception {
                // Interleave ids across subtasks: subtask t yields t, t + n, t + 2n, ...
                long id = (long) this.numTasks * this.cnt + this.taskId;
                this.cnt++;
                return Tuple2.of(str, id);
            }
        }).name("build_node_mapping");
    }

    /**
     * Replaces node values in the given fields with their generated ids.
     *
     * <p>One co-group is chained per entry of {@code iArr}. Rows are matched to the mapping on the
     * JSON text of the field. Rows whose value has no entry in the mapping are not emitted. Matched
     * rows are modified in place before being forwarded.
     *
     * @param dataSet  rows holding original node values
     * @param dataSet2 mapping of (JSON text of the node value, id)
     * @param iArr     field indices whose values are to be replaced by ids
     * @return rows with the selected fields holding {@code Long} ids
     */
    public static DataSet<Row> mapOriginalToId(DataSet<Row> dataSet, DataSet<Tuple2<String, Long>> dataSet2, int[] iArr) {
        DataSet<Row> dataSet3 = dataSet;
        for (final int i : iArr) {
            dataSet3 = dataSet3.coGroup(dataSet2).where(new KeySelector<Row, String>() {
                @Override
                public String getKey(Row row) throws Exception {
                    return JsonConverter.toJson(row.getField(i));
                }
            }).equalTo(0).with(new CoGroupFunction<Row, Tuple2<String, Long>, Row>() {
                @Override
                public void coGroup(Iterable<Row> iterable, Iterable<Tuple2<String, Long>> iterable2, Collector<Row> collector) throws Exception {
                    Iterator<Tuple2<String, Long>> it = iterable2.iterator();
                    if (!it.hasNext()) {
                        // No mapping entry for this key: the rows of this group are dropped.
                        return;
                    }
                    long id = it.next().f1;
                    for (Row row : iterable) {
                        row.setField(i, id);
                        collector.collect(row);
                    }
                }
            }).name("mapOriginToId_cogroup_index_" + i);
        }
        return dataSet3;
    }

    /**
     * Replaces generated ids in the given fields with the original node values.
     *
     * <p>One co-group is chained per entry of {@code iArr}. Rows are matched to the mapping on the
     * {@code Long} id stored in the field. Rows whose id has no entry in the mapping are not
     * emitted. Matched rows are modified in place before being forwarded.
     *
     * @param dataSet         rows holding {@code Long} ids in the selected fields
     * @param dataSet2        mapping of (JSON text of the node value, id)
     * @param iArr            field indices whose ids are to be replaced by original values
     * @param typeInformation type whose class is used to decode the JSON text of a node value
     * @return rows with the selected fields holding decoded node values
     */
    public static DataSet<Row> mapIdToOriginal(DataSet<Row> dataSet, DataSet<Tuple2<String, Long>> dataSet2, int[] iArr, final TypeInformation<?> typeInformation) {
        DataSet<Row> dataSet3 = dataSet;
        for (final int i : iArr) {
            dataSet3 = dataSet3.coGroup(dataSet2).where(new KeySelector<Row, Long>() {
                @Override
                public Long getKey(Row row) throws Exception {
                    return (Long) row.getField(i);
                }
            }).equalTo(1).with(new CoGroupFunction<Row, Tuple2<String, Long>, Row>() {
                @Override
                public void coGroup(Iterable<Row> iterable, Iterable<Tuple2<String, Long>> iterable2, Collector<Row> collector) throws Exception {
                    Iterator<Tuple2<String, Long>> it = iterable2.iterator();
                    if (!it.hasNext()) {
                        // No mapping entry for this id: the rows of this group are dropped.
                        return;
                    }
                    String json = it.next().f0;
                    for (Row row : iterable) {
                        // Decoded once per row, so every row receives its own decoded value.
                        row.setField(i, JsonConverter.fromJson(json, typeInformation.getTypeClass()));
                        collector.collect(row);
                    }
                }
            }).name("mapIdToOrigin_cogroup_index_" + i);
        }
        return dataSet3;
    }
}
