package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.graph.MultiSourceShortestPathParams;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.commons.lang3.StringUtils;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.NullValue;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRPAH_EDGES), @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRAPH_VERTICES)})
@OutputPorts(values = {@PortSpec(PortType.DATA)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "sourcePointCol", portIndices = {1}), @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})})
@NameCn("多源最短路径")
@NameEn("Multi Source Shortest Path")
/* loaded from: input_file:com/alibaba/alink/operator/batch/graph/MultiSourceShortestPathBatchOp.class */
public class MultiSourceShortestPathBatchOp extends BatchOperator<MultiSourceShortestPathBatchOp> implements MultiSourceShortestPathParams<MultiSourceShortestPathBatchOp> {
    private static final long serialVersionUID = -1637471953684406867L;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/MultiSourceShortestPathBatchOp$MapVertices.class */
    public static class MapVertices implements MapFunction<Vertex<Long, NullValue>, Tuple3<Long, Long, Double>> {
        private static final long serialVersionUID = -6624679629933017172L;

        public Tuple3<Long, Long, Double> map(Vertex<Long, NullValue> vertex) throws Exception {
            return Tuple3.of(-1L, -1L, Double.valueOf(Double.MAX_VALUE));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/MultiSourceShortestPathBatchOp$MessengerSendFunction.class */
    public static final class MessengerSendFunction extends ScatterFunction<Long, Tuple3<Long, Long, Double>, Tuple3<Long, Long, Double>, Double> {
        private static final long serialVersionUID = -2891289370485322356L;

        public void sendMessages(Vertex<Long, Tuple3<Long, Long, Double>> vertex) {
            if (((Long) ((Tuple3) vertex.getValue()).f0).longValue() < 0 || ((Long) ((Tuple3) vertex.getValue()).f1).longValue() < 0) {
                return;
            }
            for (Edge edge : getEdges()) {
                sendMessageTo(edge.getTarget(), new Tuple3(((Tuple3) vertex.getValue()).f0, vertex.getId(), Double.valueOf(((Double) ((Tuple3) vertex.getValue()).f2).doubleValue() + ((Double) edge.getValue()).doubleValue())));
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/MultiSourceShortestPathBatchOp$VertexUpdater.class */
    public static final class VertexUpdater extends GatherFunction<Long, Tuple3<Long, Long, Double>, Tuple3<Long, Long, Double>> {
        private static final long serialVersionUID = 612347525612715614L;

        public void updateVertex(Vertex<Long, Tuple3<Long, Long, Double>> vertex, MessageIterator<Tuple3<Long, Long, Double>> messageIterator) {
            double doubleValue = ((Double) ((Tuple3) vertex.f1).f2).doubleValue();
            Tuple3 tuple3 = new Tuple3();
            Iterator it = messageIterator.iterator();
            while (it.hasNext()) {
                Tuple3 tuple32 = (Tuple3) it.next();
                if (((Double) tuple32.f2).doubleValue() < doubleValue) {
                    doubleValue = ((Double) tuple32.f2).doubleValue();
                    tuple3 = tuple32;
                }
            }
            if (doubleValue < ((Double) ((Tuple3) vertex.f1).f2).doubleValue()) {
                setNewVertexValue(tuple3);
            }
        }
    }

    /**
     * Creates the operator with the supplied parameter set.
     *
     * @param params parameters handed to the {@link BatchOperator} constructor
     */
    public MultiSourceShortestPathBatchOp(Params params) {
        super(params);
    }

    /** Creates the operator with a new, empty {@link Params} instance. */
    public MultiSourceShortestPathBatchOp() {
        this(new Params());
    }

    /**
     * Wires the shortest-path job from two inputs and sets this operator's output.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Input 0 is read as the edge table (source, target and, if configured, weight column)
     *       and converted to {@code Edge<Long, Double>} by {@code GraphUtilsWithString}.</li>
     *   <li>Every vertex starts as {@code (-1, -1, Double.MAX_VALUE)}; vertices that also appear in
     *       input 1 (the source-point column) start as {@code (id, id, 0.0)}.</li>
     *   <li>A scatter-gather iteration ({@link MessengerSendFunction} / {@link VertexUpdater}) runs for
     *       at most {@code maxIter} supersteps, on the undirected graph when {@code asUndirectedGraph} is set.</li>
     *   <li>Rows with a non-null root are grouped by root and expanded into one row per vertex carrying
     *       the comma-separated node list from the vertex to the root. Rows with a null root get an
     *       empty node list and distance {@code -1.0}.</li>
     * </ol>
     *
     * <p>Output columns: {@code vertex}, {@code root_node}, {@code node_list}, {@code distance}.
     *
     * @param batchOperatorArr exactly two operators: the edge table and the source-vertex table
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public MultiSourceShortestPathBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        checkOpSize(2, batchOperatorArr);
        final String edgeSourceCol = getEdgeSourceCol();
        final String edgeTargetCol = getEdgeTargetCol();
        final String sourcePointCol = getSourcePointCol();
        final String[] outputColNames = {"vertex", "root_node", "node_list", "distance"};
        final Integer maxIter = getMaxIter();
        final Boolean asUndirectedGraph = getAsUndirectedGraph();
        final String edgeWeightCol = getEdgeWeightCol();
        final boolean hasWeight = edgeWeightCol != null;
        final String[] edgeCols = hasWeight
            ? new String[]{edgeSourceCol, edgeTargetCol, edgeWeightCol}
            : new String[]{edgeSourceCol, edgeTargetCol};

        final BatchOperator<?> edgeInput = batchOperatorArr[0];
        final TypeInformation<?> vertexType = TableUtil.findColTypeWithAssertAndHint(edgeInput.getSchema(), edgeSourceCol);
        final DataSet<Row> edgeJson = GraphUtilsWithString.input2json(edgeInput, edgeCols, 2, true);
        final GraphUtilsWithString graphUtils = new GraphUtilsWithString(edgeJson, vertexType);
        final DataSet<Edge<Long, Double>> edges = graphUtils.inputType2longEdge(edgeJson, hasWeight);

        // Every vertex derived from the edge set, in its initial (-1, -1, MAX_VALUE) state.
        final DataSet<Vertex<Long, Tuple3<Long, Long, Double>>> unreachedVertices = Graph
            .fromDataSet(edges, MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment())
            .mapVertices(new MapVertices())
            .getVertices();

        final DataSet<Row> sourceJson = GraphUtilsWithString.input2json(batchOperatorArr[1], new String[]{sourcePointCol}, 1, true);
        final DataSet<Vertex<Long, Tuple3<Long, Long, Double>>> initialVertices = unreachedVertices
            .leftOuterJoin(graphUtils.transformInputVertexWithoutWeight(sourceJson))
            .where(0)
            .equalTo(0)
            .with(new JoinFunction<Vertex<Long, Tuple3<Long, Long, Double>>, Vertex<Long, Double>, Vertex<Long, Tuple3<Long, Long, Double>>>() {
                private static final long serialVersionUID = 1964647721649366980L;

                @Override
                public Vertex<Long, Tuple3<Long, Long, Double>> join(Vertex<Long, Tuple3<Long, Long, Double>> vertex, Vertex<Long, Double> source) throws Exception {
                    if (source == null) {
                        // Not listed as a source: keep the initial state.
                        return vertex;
                    }
                    // A source is its own root and predecessor, at distance zero.
                    // NOTE(review): the decompiled code used Criteria.INVALID_GAIN here; this assumes that
                    // constant is 0.0 (decompiler constant substitution) - confirm against Criteria.
                    final Tuple3<Long, Long, Double> sourceState = Tuple3.of(vertex.f0, vertex.f0, 0.0d);
                    return new Vertex<>(vertex.f0, sourceState);
                }
            });

        Graph<Long, Tuple3<Long, Long, Double>, Double> graph = Graph.fromDataSet(
            initialVertices, edges, MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment());
        if (asUndirectedGraph) {
            graph = graph.getUndirected();
        }

        // Flatten (id, (root, predecessor, distance)) and let GraphUtilsWithString convert the result to rows.
        final DataSet<Row> resolved = graphUtils.long2StringSSSP(graph
            .runScatterGatherIteration(new MessengerSendFunction(), new VertexUpdater(), maxIter)
            .getVertices()
            .map(new MapFunction<Vertex<Long, Tuple3<Long, Long, Double>>, Tuple4<Long, Long, Long, Double>>() {
                @Override
                public Tuple4<Long, Long, Long, Double> map(Vertex<Long, Tuple3<Long, Long, Double>> vertex) throws Exception {
                    final Tuple3<Long, Long, Double> state = vertex.f1;
                    return Tuple4.of(vertex.f0, state.f0, state.f1, state.f2);
                }
            }));

        // Rows with a root: one group per root, expanded into per-vertex node lists.
        final DataSet<Row> reachable = resolved
            .filter(new FilterFunction<Row>() {
                @Override
                public boolean filter(Row row) throws Exception {
                    return row.getField(1) != null;
                }
            })
            .groupBy(new KeySelector<Row, String>() {
                @Override
                public String getKey(Row row) throws Exception {
                    return String.valueOf(row.getField(1));
                }
            })
            .reduceGroup(new ShortestPathTreeReducer());

        // Rows without a root: empty node list and distance -1.
        final DataSet<Row> unreachable = resolved
            .filter(new FilterFunction<Row>() {
                @Override
                public boolean filter(Row row) throws Exception {
                    return row.getField(1) == null;
                }
            })
            .map(new MapFunction<Row, Row>() {
                @Override
                public Row map(Row row) throws Exception {
                    row.setField(2, "");
                    row.setField(3, Double.valueOf(-1.0d));
                    return row;
                }
            });

        setOutput(reachable.union(unreachable), outputColNames,
            new TypeInformation<?>[]{vertexType, vertexType, Types.STRING, Types.DOUBLE});
        return this;
    }

    /**
     * Expands all rows sharing one root into one output row per vertex:
     * {@code (vertex, root, comma-separated nodes from vertex to root, distance)}.
     *
     * <p>Input rows are read as {@code (vertex, root, predecessor, distance)}.
     */
    private static final class ShortestPathTreeReducer implements GroupReduceFunction<Row, Row> {
        private static final long serialVersionUID = 1L;

        @Override
        public void reduce(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
            final HashMap<Object, Object> predecessorOf = new HashMap<>();
            final HashMap<Object, Double> distanceOf = new HashMap<>();
            final HashSet<Object> usedAsPredecessor = new HashSet<>();
            Object root = null;
            for (Row row : iterable) {
                root = row.getField(1);
                final Object vertex = row.getField(0);
                final Object predecessor = row.getField(2);
                usedAsPredecessor.add(predecessor);
                predecessorOf.put(vertex, predecessor);
                distanceOf.put(vertex, (Double) row.getField(3));
            }

            final HashSet<Object> emitted = new HashSet<>();
            for (Object leaf : predecessorOf.keySet()) {
                // Only start from vertices that are nobody's predecessor; inner vertices are
                // emitted while walking up from those.
                if (usedAsPredecessor.contains(leaf)) {
                    continue;
                }
                final ArrayList<Object> path = pathToRoot(leaf, root, predecessorOf);
                final ArrayList<String> labels = new ArrayList<>(path.size());
                for (Object node : path) {
                    labels.add(String.valueOf(node));
                }
                for (int i = 0; i < path.size(); i++) {
                    final Object node = path.get(i);
                    // Once a node was emitted, the rest of the path towards the root was emitted with it.
                    if (!emitted.add(node)) {
                        break;
                    }
                    collector.collect(Row.of(
                        node, root, StringUtils.join(labels.subList(i, labels.size()), ","), distanceOf.get(node)));
                }
            }

            // A group with no leaf path through the root still reports the root itself.
            if (!emitted.contains(root)) {
                collector.collect(Row.of(root, root, String.valueOf(root), distanceOf.get(root)));
            }
        }

        /**
         * Follows predecessor links from {@code leaf} until {@code root} is reached.
         *
         * @return the visited nodes, starting with {@code leaf} and ending with {@code root}
         * @throws IllegalStateException if a predecessor is missing from this group or the links form a cycle
         */
        private static ArrayList<Object> pathToRoot(Object leaf, Object root, HashMap<Object, Object> predecessorOf) {
            final ArrayList<Object> path = new ArrayList<>();
            path.add(leaf);
            Object current = leaf;
            do {
                final Object predecessor = predecessorOf.get(current);
                if (predecessor == null) {
                    throw new IllegalStateException("Broken shortest-path tree for root " + root
                        + ": no predecessor recorded for " + current + " while walking up from " + leaf);
                }
                // More nodes on the path than vertices in the group means a node repeats.
                if (path.size() > predecessorOf.size()) {
                    throw new IllegalStateException("Cyclic predecessor chain for root " + root
                        + " while walking up from " + leaf);
                }
                path.add(predecessor);
                current = predecessor;
            } while (!current.equals(root));
            return path;
        }
    }

    // The decompiled output declared a second linkFrom(BatchOperator[]) here. It was the compiler's
    // synthetic bridge for the covariant return type of linkFrom(BatchOperator<?>...) and has the same
    // source-level signature as that method, so it cannot be declared by hand; javac regenerates it.
}
