package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.comqueue.IterTaskObjKeeper;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.graph.storage.GraphEdge;
import com.alibaba.alink.operator.batch.graph.storage.HomoGraphEngine;
import com.alibaba.alink.operator.batch.graph.utils.ComputeGraphStatistics;
import com.alibaba.alink.operator.batch.graph.utils.ConstructHomoEdge;
import com.alibaba.alink.operator.batch.graph.utils.EndWritingRandomWalks;
import com.alibaba.alink.operator.batch.graph.utils.GraphPartition;
import com.alibaba.alink.operator.batch.graph.utils.GraphStatistics;
import com.alibaba.alink.operator.batch.graph.utils.HandleReceivedMessage;
import com.alibaba.alink.operator.batch.graph.utils.IDMappingUtils;
import com.alibaba.alink.operator.batch.graph.utils.LongArrayToRow;
import com.alibaba.alink.operator.batch.graph.utils.ParseGraphData;
import com.alibaba.alink.operator.batch.graph.utils.RandomWalkMemoryBuffer;
import com.alibaba.alink.operator.batch.graph.utils.ReadFromBufferAndRemoveStaticObject;
import com.alibaba.alink.operator.batch.graph.utils.RecvRequestKeySelector;
import com.alibaba.alink.operator.batch.graph.utils.SendRequestKeySelector;
import com.alibaba.alink.operator.batch.graph.walkpath.RandomWalkPathEngine;
import com.alibaba.alink.params.nlp.walk.RandomWalkParams;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.NumberSequenceIterator;

/**
 * Batch operator that generates random walk paths over a graph given as an edge table.
 *
 * <p>The input table has a source-vertex column and a target-vertex column (INT, LONG or STRING ids)
 * and an optional numeric weight column. Unless both id columns are LONG or both are INT, vertex ids
 * are first re-encoded with {@link IDMappingUtils#computeIdMapping} and mapped back when the walks are
 * rendered. The output is a single STRING column named {@value #PATH_COL_NAME}; the configured
 * delimiter is handed to the helpers that render each walk as a string.
 *
 * <p>Execution outline (see the nested function classes for details):
 * <ol>
 *   <li>Edges are parsed by {@link ParseGraphData}, custom-partitioned on field 0 with
 *       {@link GraphPartition.GraphPartitionHashFunction}, sorted on fields 0 and 1 within each
 *       partition and converted to {@link GraphEdge}s. Per-partition {@link GraphStatistics} are
 *       computed from the same data set.</li>
 *   <li>A Flink bulk iteration runs until {@link GetMessageToSend} emits nothing in a superstep (its
 *       output is the termination criterion). Superstep 1 caches each partition's graph, walk state
 *       and output buffer in {@link IterTaskObjKeeper}. Superstep 2 installs the logical-to-physical
 *       worker mapping. Later supersteps extend walks locally and ask the owners of non-local vertices
 *       to sample neighbors.</li>
 *   <li>Walks are read out of the cached memory buffer by a job outside the iteration and unioned
 *       with the output of {@link EndWritingRandomWalks}.</li>
 * </ol>
 */
@InputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.GRAPH)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "sourceCol", portIndices = 0,
        allowedTypeCollections = {TypeCollections.INT_LONG_TYPES, TypeCollections.STRING_TYPES}),
    @ParamSelectColumnSpec(name = "targetCol", portIndices = 0,
        allowedTypeCollections = {TypeCollections.INT_LONG_TYPES, TypeCollections.STRING_TYPES}),
    @ParamSelectColumnSpec(name = "weightCol", portIndices = 0,
        allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})
})
@NameCn("随机游走")
@NameEn("Random Walk")
public final class RandomWalkBatchOp extends BatchOperator<RandomWalkBatchOp>
    implements RandomWalkParams<RandomWalkBatchOp> {
    private static final long serialVersionUID = 3726910334434343013L;

    /** Name of the single output column, holding one rendered walk per row. */
    public static final String PATH_COL_NAME = "path";

    /** Broadcast-set name under which the per-partition {@link GraphStatistics} are shared. */
    public static final String GRAPH_STATISTICS = "graphStatistics";

    /** Broadcast-set name of the iteration's partial solution, read by {@link CacheGraphAndRandomWalk}. */
    private static final String LOOP_BROADCAST = "loop";

    /**
     * Per-partition setup function that runs inside the iteration on the cached graph edges.
     *
     * <p>Superstep 1: builds the {@link HomoGraphEngine}, {@link RandomWalkPathEngine} and
     * {@link RandomWalkMemoryBuffer} for this subtask and stores them in {@link IterTaskObjKeeper}. It
     * then emits one unit whose srcPartitionId is the logical partition of the first local source
     * vertex and whose dstPartitionId is this subtask's index.
     *
     * <p>Superstep 2: reads those units back through the {@code loop} broadcast and installs the
     * resulting logical-to-physical worker map on the graph engine. Later supersteps do nothing.
     */
    public static class CacheGraphAndRandomWalk
        extends RichMapPartitionFunction<GraphEdge, RandomWalkCommunicationUnit> {

        // Budget divided by (numWalkPerVertex * walkLen) to size the walk batch; presumably bounds the
        // number of path cells held in memory at once -- TODO confirm against RandomWalkPathEngine.
        private static final int MAX_IN_FLIGHT_PATH_CELLS = 1 << 24;

        // Budget divided by (walkLen * 8 + 16) to size the walk buffer; presumably bytes (8 per long
        // vertex id plus per-walk overhead) -- TODO confirm against RandomWalkMemoryBuffer.
        private static final int WALK_BUFFER_BUDGET = 1 << 26;

        final long graphStorageHandler;
        final long randomWalkStorageHandler;
        final long walkWriteBufferHandler;
        final long walkReadBufferHandler;
        final int numWalkPerVertex;
        final int walkLen;
        final boolean isWeighted;
        final String samplingMethod;
        final GraphPartition.GraphPartitionFunction graphPartitionFunction;

        /**
         * @param graphStorageHandler      IterTaskObjKeeper handle for the {@link HomoGraphEngine}
         * @param randomWalkStorageHandler IterTaskObjKeeper handle for the {@link RandomWalkPathEngine}
         * @param walkWriteBufferHandler   handle under which the walk buffer is stored for writers
         * @param walkReadBufferHandler    handle under which the same buffer is stored for the reader
         * @param numWalkPerVertex         walks to start from each source vertex
         * @param walkLen                  length of each walk
         * @param isWeighted               whether a weight column was configured
         * @param samplingMethod           sampling method name; "ALIAS" (any case) enables alias sampling
         * @param graphPartitionFunction   maps a vertex id to its logical partition
         */
        public CacheGraphAndRandomWalk(long graphStorageHandler, long randomWalkStorageHandler,
                                       long walkWriteBufferHandler, long walkReadBufferHandler,
                                       int numWalkPerVertex, int walkLen, boolean isWeighted,
                                       String samplingMethod,
                                       GraphPartition.GraphPartitionFunction graphPartitionFunction) {
            this.graphStorageHandler = graphStorageHandler;
            this.randomWalkStorageHandler = randomWalkStorageHandler;
            this.walkWriteBufferHandler = walkWriteBufferHandler;
            this.walkReadBufferHandler = walkReadBufferHandler;
            this.numWalkPerVertex = numWalkPerVertex;
            this.walkLen = walkLen;
            this.isWeighted = isWeighted;
            this.samplingMethod = samplingMethod;
            this.graphPartitionFunction = graphPartitionFunction;
        }

        @Override
        public void mapPartition(Iterable<GraphEdge> edges, Collector<RandomWalkCommunicationUnit> out)
            throws Exception {
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int superstep = getIterationRuntimeContext().getSuperstepNumber();
            if (superstep == 1) {
                cacheGraphAndWalkState(edges, out, taskId, numTasks);
            } else if (superstep == 2) {
                installWorkerMapping(taskId);
            }
            // Supersteps > 2: everything is already cached; nothing to do.
        }

        /** Superstep 1: build and cache the graph, walk engine and buffer; announce this partition. */
        private void cacheGraphAndWalkState(Iterable<GraphEdge> edges, Collector<RandomWalkCommunicationUnit> out,
                                            int taskId, int numTasks) throws Exception {
            List<GraphStatistics> allStatistics = getRuntimeContext().getBroadcastVariable(GRAPH_STATISTICS);
            GraphStatistics localStatistics = null;
            for (GraphStatistics stats : allStatistics) {
                if (stats.getPartitionId() == taskId) {
                    localStatistics = stats;
                    break;
                }
            }
            // Previously a missing entry surfaced as an anonymous NullPointerException below.
            AkPreconditions.checkNotNull(localStatistics, "No graph statistics found for partition " + taskId);

            HomoGraphEngine graphEngine = new HomoGraphEngine(edges, localStatistics.getVertexNum(),
                localStatistics.getEdgeNum(), isWeighted, "ALIAS".equalsIgnoreCase(samplingMethod));
            IterTaskObjKeeper.put(graphStorageHandler, taskId, graphEngine);

            // Long arithmetic avoids int overflow of numWalkPerVertex * walkLen. Clamping to >= 1 keeps a
            // very large walk configuration from producing a zero-sized walk batch.
            int walkSlots = (int) Math.max(1L, MAX_IN_FLIGHT_PATH_CELLS / ((long) numWalkPerVertex * walkLen));
            IterTaskObjKeeper.put(randomWalkStorageHandler, taskId, new RandomWalkPathEngine(
                Math.min(walkSlots, localStatistics.getVertexNum()), numWalkPerVertex, walkLen,
                graphEngine.getAllSrcVertices()));

            // Same overflow / zero-capacity guard for the walk output buffer.
            int bufferCapacity = (int) Math.max(1L, WALK_BUFFER_BUDGET / ((long) walkLen * 8 + 16));
            RandomWalkMemoryBuffer walkBuffer = new RandomWalkMemoryBuffer(bufferCapacity);
            // One buffer, two handles: the write handle goes to GetMessageToSend/EndWritingRandomWalks
            // and the read handle goes to ReadFromBufferAndRemoveStaticObject (see linkFrom).
            IterTaskObjKeeper.put(walkWriteBufferHandler, taskId, walkBuffer);
            IterTaskObjKeeper.put(walkReadBufferHandler, taskId, walkBuffer);

            if (graphEngine.getAllSrcVertices().hasNext()) {
                long firstVertex = graphEngine.getAllSrcVertices().next().longValue();
                // srcPartitionId = logical partition owning this data, dstPartitionId = physical subtask.
                out.collect(new RandomWalkCommunicationUnit(
                    graphPartitionFunction.apply(firstVertex, numTasks), taskId, null, null));
            }
        }

        /** Superstep 2: turn the superstep-1 announcements into a logical-to-physical worker map. */
        private void installWorkerMapping(int taskId) {
            List<RandomWalkCommunicationUnit> announcements =
                getRuntimeContext().getBroadcastVariable(LOOP_BROADCAST);
            HashMap<Integer, Integer> logicalToPhysical = new HashMap<>(announcements.size());
            for (RandomWalkCommunicationUnit announcement : announcements) {
                logicalToPhysical.put(announcement.getSrcPartitionId(), announcement.getDstPartitionId());
            }
            HomoGraphEngine graphEngine = (HomoGraphEngine) IterTaskObjKeeper.get(graphStorageHandler, taskId);
            AkPreconditions.checkNotNull(graphEngine, "homoGraphEngine is null");
            graphEngine.setLogicalWorkerIdToPhysicalWorkerId(logicalToPhysical);
        }
    }

    /**
     * Answers neighbor-sampling requests addressed to this subtask.
     *
     * <p>In superstep 1 units are forwarded unchanged. Afterwards, each requested vertex id is replaced
     * in place by one sampled neighbor if the vertex is stored locally, or by -1 otherwise. The updated
     * unit is then forwarded.
     */
    public static class DoRemoteProcessing
        extends RichMapPartitionFunction<RandomWalkCommunicationUnit, RandomWalkCommunicationUnit> {
        final long graphStorageHandler;

        /** @param graphStorageHandler IterTaskObjKeeper handle for the cached {@link HomoGraphEngine} */
        public DoRemoteProcessing(long graphStorageHandler) {
            this.graphStorageHandler = graphStorageHandler;
        }

        @Override
        public void mapPartition(Iterable<RandomWalkCommunicationUnit> requests,
                                 Collector<RandomWalkCommunicationUnit> out) throws Exception {
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                for (RandomWalkCommunicationUnit unit : requests) {
                    out.collect(unit);
                }
                return;
            }
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            HomoGraphEngine graphEngine = (HomoGraphEngine) IterTaskObjKeeper.get(graphStorageHandler, taskId);
            AkPreconditions.checkNotNull(graphEngine, "homoGraphEngine is null");
            for (RandomWalkCommunicationUnit unit : requests) {
                // Routing upstream must deliver every unit to its dstPartitionId.
                AkPreconditions.checkState(unit.getDstPartitionId() == taskId,
                    "The target task id is incorrect. It should be " + unit.getDstPartitionId()
                        + ", but it is " + taskId);
                Long[] vertexIds = unit.getRequestedVertexIds();
                for (int k = 0; k < vertexIds.length; k++) {
                    long vertex = vertexIds[k];
                    vertexIds[k] = graphEngine.containsVertex(vertex) ? graphEngine.sampleOneNeighbor(vertex) : -1L;
                }
                out.collect(unit);
            }
        }
    }

    /**
     * Advances all in-flight walks of this partition and emits sampling requests for non-local vertices.
     *
     * <p>In superstep 1 units are forwarded unchanged. Afterwards, each walk slot is extended with
     * local neighbor samples for as long as its current vertex is stored on this partition. Completed
     * walks are written to the memory buffer. Slots whose current vertex is not -1 are then grouped by
     * the physical worker owning that vertex, and one unit is emitted per worker.
     */
    public static class GetMessageToSend
        extends RichMapPartitionFunction<RandomWalkCommunicationUnit, RandomWalkCommunicationUnit> {
        final long graphStorageHandler;
        final long randomWalkStorageHandler;
        final long walkBufferHandler;
        final GraphPartition.GraphPartitionFunction graphPartitionFunction;

        /**
         * @param graphStorageHandler      handle of the cached {@link HomoGraphEngine}
         * @param randomWalkStorageHandler handle of the cached {@link RandomWalkPathEngine}
         * @param walkBufferHandler        handle of the {@link RandomWalkMemoryBuffer} to write walks to
         * @param graphPartitionFunction   maps a vertex id to its logical partition
         */
        public GetMessageToSend(long graphStorageHandler, long randomWalkStorageHandler, long walkBufferHandler,
                                GraphPartition.GraphPartitionFunction graphPartitionFunction) {
            this.graphStorageHandler = graphStorageHandler;
            this.randomWalkStorageHandler = randomWalkStorageHandler;
            this.walkBufferHandler = walkBufferHandler;
            this.graphPartitionFunction = graphPartitionFunction;
        }

        @Override
        public void mapPartition(Iterable<RandomWalkCommunicationUnit> values,
                                 Collector<RandomWalkCommunicationUnit> out) throws Exception {
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                for (RandomWalkCommunicationUnit unit : values) {
                    out.collect(unit);
                }
                return;
            }
            // Intentionally drain the input without using it. This presumably makes sure the upstream
            // CacheGraphAndRandomWalk has finished (e.g. installed the worker map in superstep 2)
            // before the cached state is used below.
            for (RandomWalkCommunicationUnit ignored : values) {
                // nothing to do
            }

            HomoGraphEngine graphEngine = (HomoGraphEngine) IterTaskObjKeeper.get(graphStorageHandler, taskId);
            RandomWalkPathEngine walkEngine =
                (RandomWalkPathEngine) IterTaskObjKeeper.get(randomWalkStorageHandler, taskId);
            RandomWalkMemoryBuffer walkBuffer = (RandomWalkMemoryBuffer) IterTaskObjKeeper.get(walkBufferHandler, taskId);
            AkPreconditions.checkNotNull(graphEngine, "homoGraphEngine is null");
            AkPreconditions.checkNotNull(walkEngine, "randomWalkPathEngine is null");
            AkPreconditions.checkNotNull(walkBuffer, "randomWalkMemoryBuffer is null");

            long[] currentVertices = walkEngine.getNextBatchOfVerticesToSampleFrom();
            for (int slot = 0; slot < currentVertices.length; slot++) {
                if (walkEngine.canOutput(slot)) {
                    long[] finishedWalk = walkEngine.getOneWalkAndAddNewWalk(slot);
                    currentVertices[slot] = walkEngine.getNextVertexToSampleFrom(slot);
                    walkBuffer.writeOneWalk(finishedWalk);
                }
                // Keep walking while the current vertex is stored on this partition.
                long vertex = currentVertices[slot];
                while (graphEngine.containsVertex(vertex)) {
                    vertex = graphEngine.sampleOneNeighbor(vertex);
                    walkEngine.updatePath(slot, vertex);
                    if (walkEngine.canOutput(slot)) {
                        walkBuffer.writeOneWalk(walkEngine.getOneWalkAndAddNewWalk(slot));
                        vertex = walkEngine.getNextVertexToSampleFrom(slot);
                    }
                }
                currentVertices[slot] = vertex;
            }

            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            List<List<Long>> verticesByWorker = new ArrayList<>(numTasks);
            List<List<Integer>> walkIdsByWorker = new ArrayList<>(numTasks);
            for (int worker = 0; worker < numTasks; worker++) {
                verticesByWorker.add(new ArrayList<Long>());
                walkIdsByWorker.add(new ArrayList<Integer>());
            }
            for (int slot = 0; slot < currentVertices.length; slot++) {
                long vertex = currentVertices[slot];
                if (vertex != -1) {
                    int worker = graphEngine.getPhysicalWorkerIdByLogicalWorkerId(
                        graphPartitionFunction.apply(vertex, numTasks));
                    verticesByWorker.get(worker).add(vertex);
                    walkIdsByWorker.get(worker).add(slot);
                }
            }
            for (int worker = 0; worker < numTasks; worker++) {
                List<Integer> walkIds = walkIdsByWorker.get(worker);
                if (!walkIds.isEmpty()) {
                    out.collect(new RandomWalkCommunicationUnit(taskId, worker,
                        walkIds.toArray(new Integer[0]), verticesByWorker.get(worker).toArray(new Long[0])));
                }
            }
        }
    }

    /**
     * Message exchanged inside the iteration: a batch of walk ids and the vertex ids requested or
     * answered for them, plus the sending and receiving partition ids.
     */
    public static class RandomWalkCommunicationUnit implements Serializable {
        int srcPartitionId;
        int dstPartitionId;
        Long[] requestedVertexIds;
        Integer[] walkIds;

        /**
         * @param srcPartitionId     sending partition (or logical partition id in superstep-1 announcements)
         * @param dstPartitionId     receiving partition (or physical subtask id in announcements)
         * @param walkIds            walk slot ids, parallel to {@code requestedVertexIds}; may be null
         * @param requestedVertexIds vertex ids to sample from / sampled results; may be null
         */
        public RandomWalkCommunicationUnit(int srcPartitionId, int dstPartitionId, Integer[] walkIds,
                                           Long[] requestedVertexIds) {
            this(srcPartitionId, dstPartitionId, walkIds, requestedVertexIds, null);
        }

        /**
         * Same as the four-argument constructor. The trailing {@code Character[]} argument is accepted
         * for compatibility but is not stored.
         */
        public RandomWalkCommunicationUnit(int srcPartitionId, int dstPartitionId, Integer[] walkIds,
                                           Long[] requestedVertexIds, Character[] unused) {
            this.srcPartitionId = srcPartitionId;
            this.dstPartitionId = dstPartitionId;
            this.walkIds = walkIds;
            this.requestedVertexIds = requestedVertexIds;
        }

        public int getDstPartitionId() {
            return this.dstPartitionId;
        }

        public int getSrcPartitionId() {
            return this.srcPartitionId;
        }

        public Integer[] getWalkIds() {
            return this.walkIds;
        }

        public Long[] getRequestedVertexIds() {
            return this.requestedVertexIds;
        }
    }

    /**
     * Seeds the iteration with one placeholder unit per input element. Its fields are not read:
     * CacheGraphAndRandomWalk only consumes the {@code loop} broadcast from superstep 2 on.
     */
    private static final class InitCommunicationUnit implements MapFunction<Long, RandomWalkCommunicationUnit> {
        private static final long serialVersionUID = 1L;

        @Override
        public RandomWalkCommunicationUnit map(Long ignored) {
            return new RandomWalkCommunicationUnit(1, 1, null, null);
        }
    }

    /** Maps each outgoing request to a token; the iteration stops once no tokens are produced. */
    private static final class ToTerminationSignal implements MapFunction<RandomWalkCommunicationUnit, Object> {
        private static final long serialVersionUID = 1L;

        @Override
        public Object map(RandomWalkCommunicationUnit ignored) {
            return new Object();
        }
    }

    public RandomWalkBatchOp(Params params) {
        super(params);
    }

    public RandomWalkBatchOp() {
        this(new Params());
    }

    /**
     * Builds the random-walk job on the single input edge table.
     *
     * <p>Raw data-set types are used for the intermediate stages because the element types of several
     * helper functions (ParseGraphData, EndWritingRandomWalks, ...) are declared elsewhere.
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"})
    public RandomWalkBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        int walkNum = getWalkNum().intValue();
        int walkLength = getWalkLength().intValue();
        String delimiter = getDelimiter();
        String weightCol = getWeightCol();
        Boolean isToUndigraph = getIsToUndigraph();

        String[] colNames = in.getColNames();
        int sourceColIdx = TableUtil.findColIndexWithAssertAndHint(colNames, getSourceCol());
        int targetColIdx = TableUtil.findColIndexWithAssertAndHint(colNames, getTargetCol());
        boolean isWeighted = weightCol != null;
        int weightColIdx = isWeighted ? TableUtil.findColIndexWithAssertAndHint(colNames, weightCol) : -1;

        // getColTypes() yields TypeInformation<?>; the old BasicTypeInfo declaration did not compile.
        TypeInformation<?> sourceType = in.getColTypes()[sourceColIdx];
        TypeInformation<?> targetType = in.getColTypes()[targetColIdx];
        boolean bothLong = BasicTypeInfo.LONG_TYPE_INFO.equals(sourceType)
            && BasicTypeInfo.LONG_TYPE_INFO.equals(targetType);
        boolean bothInt = BasicTypeInfo.INT_TYPE_INFO.equals(sourceType)
            && BasicTypeInfo.INT_TYPE_INFO.equals(targetType);
        boolean needIdMapping = !(bothLong || bothInt);

        DataSet<Tuple2<String, Long>> idMapping = null;
        DataSet<Row> edgeRows;
        if (needIdMapping) {
            idMapping = IDMappingUtils.computeIdMapping(in.getDataSet(), new int[] {sourceColIdx, targetColIdx});
            edgeRows = IDMappingUtils.mapDataSetWithIdMapping(
                in.getDataSet(), idMapping, new int[] {sourceColIdx, targetColIdx});
        } else {
            edgeRows = in.getDataSet();
        }

        GraphPartition.GraphPartitionHashFunction partitionFunction = new GraphPartition.GraphPartitionHashFunction();
        DataSet parsedGraph = edgeRows
            .flatMap(new ParseGraphData(sourceColIdx, targetColIdx, weightColIdx, isToUndigraph.booleanValue()))
            .partitionCustom(new GraphPartition.GraphPartitioner(partitionFunction), 0)
            .sortPartition(0, Order.ASCENDING)
            .sortPartition(1, Order.ASCENDING)
            .map(new ConstructHomoEdge())
            .name("parsedAndPartitionedAndSortedGraph");
        DataSet graphStatistics = parsedGraph.mapPartition(new ComputeGraphStatistics()).name(GRAPH_STATISTICS);

        long graphHandle = IterTaskObjKeeper.getNewHandle();
        long walkPathHandle = IterTaskObjKeeper.getNewHandle();
        long writeBufferHandle = IterTaskObjKeeper.getNewHandle();
        long readBufferHandle = IterTaskObjKeeper.getNewHandle();

        ExecutionEnvironment env = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment();
        IterativeDataSet loop = env
            .fromParallelCollection(new NumberSequenceIterator(1L, env.getParallelism()), BasicTypeInfo.LONG_TYPE_INFO)
            .map(new InitCommunicationUnit())
            .name("initData")
            .iterate(Integer.MAX_VALUE)
            .name("loop");

        DataSet sendRequests = parsedGraph
            .mapPartition(new CacheGraphAndRandomWalk(graphHandle, walkPathHandle, writeBufferHandle,
                readBufferHandle, walkNum, walkLength, isWeighted, getSamplingMethod(), partitionFunction))
            .withBroadcastSet(graphStatistics, GRAPH_STATISTICS)
            .withBroadcastSet(loop, LOOP_BROADCAST)
            .name("cacheGraphAndRandomWalk")
            .mapPartition(new GetMessageToSend(graphHandle, walkPathHandle, writeBufferHandle, partitionFunction))
            .name("sendCommunicationUnit");

        DataSet finishedOneStep = sendRequests
            .partitionCustom(new GraphPartition.GraphPartitioner(partitionFunction), new SendRequestKeySelector())
            .mapPartition(new DoRemoteProcessing(graphHandle))
            .partitionCustom(new GraphPartition.GraphPartitioner(partitionFunction), new RecvRequestKeySelector())
            .name("recvCommunicationUnit")
            .mapPartition(new HandleReceivedMessage(walkPathHandle))
            .name("finishedOneStep");
        // The iteration ends in the first superstep in which GetMessageToSend emits nothing.
        DataSet termination = sendRequests.map(new ToTerminationSignal()).name("termination");
        DataSet iterationResult = loop.closeWith(finishedOneStep, termination);

        DataSet walks = env
            .fromParallelCollection(new NumberSequenceIterator(0L, env.getParallelism()), BasicTypeInfo.LONG_TYPE_INFO)
            .mapPartition(new ReadFromBufferAndRemoveStaticObject(graphHandle, walkPathHandle, readBufferHandle, delimiter))
            .name("memoryOut")
            .union(iterationResult
                .mapPartition(new EndWritingRandomWalks(writeBufferHandle))
                .name("emptyOut")
                .withBroadcastSet(iterationResult, "output"))
            .name("mergedOutput");

        DataSet output;
        if (needIdMapping) {
            output = IDMappingUtils.mapWalkToStringWithIdMapping(walks, idMapping, walkLength, delimiter);
        } else {
            output = walks.map(new LongArrayToRow(delimiter)).name("finalOutput");
        }
        setOutput(output, new TableSchema(new String[] {PATH_COL_NAME}, new TypeInformation[] {Types.STRING}));
        return this;
    }
}
