package com.alibaba.alink.operator.batch.graph.utils;

import java.util.Arrays;
import java.util.Iterator;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/batch/graph/utils/IDMappingUtils.class */
public class IDMappingUtils {
    public static DataSet<Tuple2<String, Long>> computeIdMapping(DataSet<Row> dataSet, final int[] iArr) {
        return dataSet.mapPartition(new MapPartitionFunction<Row, String>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.2
            public void mapPartition(Iterable<Row> iterable, Collector<String> collector) throws Exception {
                for (Row row : iterable) {
                    for (int i = 0; i < iArr.length; i++) {
                        collector.collect((String) row.getField(iArr[i]));
                    }
                }
            }
        }).distinct().mapPartition(new RichMapPartitionFunction<String, Tuple2<String, Long>>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.1
            public void mapPartition(Iterable<String> iterable, Collector<Tuple2<String, Long>> collector) {
                long j = 0;
                int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
                int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
                Iterator<String> it = iterable.iterator();
                while (it.hasNext()) {
                    long j2 = j;
                    j = j2 + 1;
                    collector.collect(Tuple2.of(it.next(), Long.valueOf((numberOfParallelSubtasks * j2) + indexOfThisSubtask)));
                }
            }
        }).name("build_node_mapping");
    }

    public static DataSet<Row> mapDataSetWithIdMapping(DataSet<Row> dataSet, DataSet<Tuple2<String, Long>> dataSet2, int[] iArr) {
        DataSet<Row> dataSet3 = dataSet;
        for (final int i : iArr) {
            dataSet3 = dataSet3.coGroup(dataSet2).where(new KeySelector<Row, String>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.4
                public String getKey(Row row) throws Exception {
                    return (String) row.getField(i);
                }
            }).equalTo(new int[]{0}).with(new CoGroupFunction<Row, Tuple2<String, Long>, Row>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.3
                public void coGroup(Iterable<Row> iterable, Iterable<Tuple2<String, Long>> iterable2, Collector<Row> collector) throws Exception {
                    Iterator<Tuple2<String, Long>> it = iterable2.iterator();
                    if (it.hasNext()) {
                        long longValue = ((Long) it.next().f1).longValue();
                        for (Row row : iterable) {
                            row.setField(i, Long.valueOf(longValue));
                            collector.collect(row);
                        }
                    }
                }
            }).name("cogroup at " + i);
        }
        return dataSet3;
    }

    public static DataSet<Row> recoverDataSetWithIdMapping(DataSet<Row> dataSet, DataSet<Tuple2<String, Long>> dataSet2, int[] iArr) {
        DataSet<Row> dataSet3 = dataSet;
        for (final int i : iArr) {
            dataSet3 = dataSet3.coGroup(dataSet2).where(new KeySelector<Row, Long>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.6
                public Long getKey(Row row) throws Exception {
                    return (Long) row.getField(i);
                }
            }).equalTo(new int[]{1}).with(new CoGroupFunction<Row, Tuple2<String, Long>, Row>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.5
                public void coGroup(Iterable<Row> iterable, Iterable<Tuple2<String, Long>> iterable2, Collector<Row> collector) throws Exception {
                    Iterator<Tuple2<String, Long>> it = iterable2.iterator();
                    if (it.hasNext()) {
                        String str = (String) it.next().f0;
                        for (Row row : iterable) {
                            row.setField(i, str);
                            collector.collect(row);
                        }
                    }
                }
            }).name("cogroup at " + i);
        }
        return dataSet3;
    }

    public static DataSet<Row> mapWalkToStringWithIdMapping(DataSet<long[]> dataSet, DataSet<Tuple2<String, Long>> dataSet2, final int i, final String str) {
        return dataSet.flatMap(new RichFlatMapFunction<long[], Tuple3<Long, Long, Long>>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.7
            long cnt;
            int numTasks;
            int taskId;

            public void open(Configuration configuration) throws Exception {
                this.cnt = 0L;
                this.numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
                this.taskId = getRuntimeContext().getIndexOfThisSubtask();
            }

            public void flatMap(long[] jArr, Collector<Tuple3<Long, Long, Long>> collector) throws Exception {
                for (int i2 = 0; i2 < jArr.length; i2++) {
                    collector.collect(Tuple3.of(Long.valueOf((this.cnt * this.numTasks) + this.taskId), Long.valueOf(i2), Long.valueOf(jArr[i2])));
                }
                this.cnt++;
            }

            public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                flatMap((long[]) obj, (Collector<Tuple3<Long, Long, Long>>) collector);
            }
        }).name("int2string_map").coGroup(dataSet2).where(new int[]{2}).equalTo(new int[]{1}).with(new CoGroupFunction<Tuple3<Long, Long, Long>, Tuple2<String, Long>, Tuple3<Long, Long, String>>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.9
            private static final long serialVersionUID = 1881422618869562798L;

            public void coGroup(Iterable<Tuple3<Long, Long, Long>> iterable, Iterable<Tuple2<String, Long>> iterable2, Collector<Tuple3<Long, Long, String>> collector) throws Exception {
                Iterator<Tuple2<String, Long>> it = iterable2.iterator();
                if (it.hasNext()) {
                    String str2 = (String) it.next().f0;
                    for (Tuple3<Long, Long, Long> tuple3 : iterable) {
                        collector.collect(Tuple3.of(tuple3.f0, tuple3.f1, str2));
                    }
                }
            }
        }).name("int2string_cogroup").groupBy(new int[]{0}).reduceGroup(new GroupReduceFunction<Tuple3<Long, Long, String>, Row>() { // from class: com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils.8
            private static final long serialVersionUID = 4413395610271058332L;
            String[] tmpArray;
            int maxIdx;

            {
                this.tmpArray = new String[i];
            }

            public void reduce(Iterable<Tuple3<Long, Long, String>> iterable, Collector<Row> collector) throws Exception {
                this.maxIdx = -1;
                for (Tuple3<Long, Long, String> tuple3 : iterable) {
                    int intValue = ((Long) tuple3.f1).intValue();
                    this.tmpArray[intValue] = (String) tuple3.f2;
                    this.maxIdx = Math.max(this.maxIdx, intValue);
                }
                StringBuilder sb = new StringBuilder();
                for (int i2 = 0; i2 < this.maxIdx; i2++) {
                    sb.append(this.tmpArray[i2]);
                    sb.append(str);
                }
                sb.append(this.tmpArray[this.maxIdx]);
                Row row = new Row(1);
                row.setField(0, sb.toString());
                collector.collect(row);
            }
        }).name("int2string_reduce");
    }
}
