package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.graph.SingleSourceShortestPathParams;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.Neighbor;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.NullValue;
import org.apache.flink.types.Row;

@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRPAH_EDGES)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})})
@NameCn("单源最短路径")
@NameEn("Single Source Shortest Path")
/* loaded from: input_file:com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp.class */
public class SingleSourceShortestPathBatchOp extends BatchOperator<SingleSourceShortestPathBatchOp> implements SingleSourceShortestPathParams<SingleSourceShortestPathBatchOp> {
    private static final long serialVersionUID = -1637471953684406867L;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp$CalculateDistances.class */
    public static final class CalculateDistances extends GatherFunction<Double, Double, Double> {
        private static final long serialVersionUID = 8666543088654403877L;

        private CalculateDistances() {
        }

        public Double gather(Neighbor<Double, Double> neighbor) {
            return Double.valueOf(((Double) neighbor.getNeighborValue()).doubleValue() + ((Double) neighbor.getEdgeValue()).doubleValue());
        }

        /* renamed from: gather, reason: collision with other method in class */
        public /* bridge */ /* synthetic */ Object m264gather(Neighbor neighbor) {
            return gather((Neighbor<Double, Double>) neighbor);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp$ChooseMinDistance.class */
    public static final class ChooseMinDistance extends SumFunction<Double, Double, Double> {
        private static final long serialVersionUID = 7176693441223938280L;

        private ChooseMinDistance() {
        }

        public Double sum(Double d, Double d2) {
            return Double.valueOf(Math.min(d.doubleValue(), d2.doubleValue()));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp$MapVertices.class */
    public static class MapVertices implements MapFunction<Vertex<Long, NullValue>, Double> {
        private static final long serialVersionUID = -6624679629933017172L;

        public Double map(Vertex<Long, NullValue> vertex) throws Exception {
            return Double.valueOf(Double.POSITIVE_INFINITY);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp$UpdateDistance.class */
    public static final class UpdateDistance extends ApplyFunction<Long, Double, Double> {
        private static final long serialVersionUID = 7753801433080130491L;

        private UpdateDistance() {
        }

        public void apply(Double d, Double d2) {
            if (d.doubleValue() < d2.doubleValue()) {
                setResult(d);
            }
        }
    }

    /** Creates the operator with a fresh, empty {@link Params} instance. */
    public SingleSourceShortestPathBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator from an existing parameter set.
     *
     * @param params parameters passed straight to the {@code BatchOperator} constructor
     */
    public SingleSourceShortestPathBatchOp(Params params) {
        super(params);
    }

    /**
     * Computes the shortest-path distance from the configured source point to every vertex of
     * the input edge table and publishes it as a two-column output
     * ({@code "vertex"}, {@code "distance"}).
     *
     * <p>Every vertex starts at {@link Double#POSITIVE_INFINITY}; the vertex matching the source
     * point is reset to {@link Criteria#INVALID_GAIN}. Distances are then relaxed with a
     * gather-sum-apply iteration ({@link CalculateDistances}, {@link ChooseMinDistance},
     * {@link UpdateDistance}) bounded by {@code getMaxIter()}.
     *
     * @param batchOperatorArr upstream operators; the first one is used as the edge table
     * @return this operator, with its output set
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public SingleSourceShortestPathBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        final BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        final String sourceCol = getEdgeSourceCol();
        final String targetCol = getEdgeTargetCol();
        final String weightCol = getEdgeWeightCol();
        final String sourcePoint = getSourcePoint();
        final int maxIter = getMaxIter();
        final boolean asUndirected = getAsUndirectedGraph();
        final boolean hasWeightCol = weightCol != null;

        // The weight column is optional; only select it when it was configured.
        final String[] edgeCols = hasWeightCol
            ? new String[] {sourceCol, targetCol, weightCol}
            : new String[] {sourceCol, targetCol};
        // The output vertex column reuses the type of the edge source column.
        final TypeInformation<?> vertexType =
            in.getColTypes()[TableUtil.findColIndexWithAssertAndHint(in.getColNames(), sourceCol)];

        // NOTE(review): the literal arguments (2, true) are defined by GraphUtilsWithString.input2json,
        // which is not visible here -- confirm their meaning against that helper.
        final DataSet<Row> jsonEdges = GraphUtilsWithString.input2json(in, edgeCols, 2, true);
        final GraphUtilsWithString idMapper = new GraphUtilsWithString(jsonEdges, vertexType);
        final DataSet<Edge<Long, Double>> edges = idMapper.inputType2longEdge(jsonEdges, hasWeightCol);

        final ExecutionEnvironment env =
            MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment();

        // Derive the vertex set from the edges, initialise every distance to +infinity, then
        // left-join against the source vertex so that only the source gets the start distance.
        final DataSet<Vertex<Long, Double>> initialVertices = Graph.fromDataSet(edges, env)
            .mapVertices(new MapVertices())
            .getVertices()
            .leftOuterJoin(idMapper.string2longSource(sourcePoint, in.getMLEnvironmentId(), vertexType))
            .where(0)
            .equalTo(0)
            .with(new JoinFunction<Vertex<Long, Double>, Tuple1<Long>, Vertex<Long, Double>>() {
                private static final long serialVersionUID = 1964647721649366980L;

                @Override
                public Vertex<Long, Double> join(Vertex<Long, Double> vertex, Tuple1<Long> source)
                    throws Exception {
                    // Left outer join: 'source' is null for every vertex that is not the source.
                    if (source != null) {
                        // NOTE(review): Criteria.INVALID_GAIN is presumably 0.0 (the distance of the
                        // source to itself); a dedicated distance constant would read better -- confirm.
                        vertex.f1 = Criteria.INVALID_GAIN;
                    }
                    return vertex;
                }
            });

        Graph<Long, Double, Double> graph = Graph.fromDataSet(initialVertices, edges, env);
        if (asUndirected) {
            graph = graph.getUndirected();
        }

        final DataSet<Vertex<Long, Double>> distances = graph
            .runGatherSumApplyIteration(
                new CalculateDistances(), new ChooseMinDistance(), new UpdateDistance(), maxIter)
            .getVertices();

        setOutput(
            idMapper.double2outputTypeVertex(distances, Types.DOUBLE),
            new String[] {"vertex", "distance"},
            new TypeInformation<?>[] {vertexType, Types.DOUBLE});
        return this;
    }
}
