单源最短距离

单源最短距离是指给定图中一个源点,计算源点到其它所有节点的最短距离。Dijkstra算法是求解有向图中单源最短距离SSSP(Single Source Shortest Path)的经典算法。

算法原理

Dijkstra算法是通过去更新最短距离值,每个维护到源点的当前最短距离值,当这个值发生变化时,将新值加上权值,发送消息通知其邻接点。下一轮迭代时,邻接点根据收到的消息,更新其当前最短距离值,当所有的当前最短距离值不再变化时,迭代结束。

  • 初始化:源点ss自身的距离为0(d[s]=0),其他点us的距离为无穷(d[u]=∞)。

  • 迭代:如果存在一条从uv的边,则从sv的最短距离更新为d[v]=min(d[v], d[u]+weight(u, v)),直到所有的点到s的距离不再发生变化时,迭代结束。

说明

对一个有权重的有向图G=(V,E),从一个源点s到汇点v有很多路径,其中边权和最小的路径,称为从sv的最短距离。

由算法基本原理可以看出,此算法非常适合用MaxCompute Graph程序进行求解。

使用场景

图类型通常分为有向图和无向图两种,MaxCompute均支持。基于源数据的分布,构造有向图和无向图时的路径通路会存在差异,求解SSSP时会产生不同的结果。MaxCompute Graph以有向图为基础数据模型,框架内会以有向图的模型参与计算。

代码示例

以下代码基于不同的场景,提供不同的代码示例。

  • 有向图

    • 定义类BaseLoadingVertexResolver,此异常类会在SSSP类中被引用。

      import com.aliyun.odps.graph.Edge;
      import com.aliyun.odps.graph.LoadingVertexResolver;
      import com.aliyun.odps.graph.Vertex;
      import com.aliyun.odps.graph.VertexChanges;
      import com.aliyun.odps.io.Writable;
      import com.aliyun.odps.io.WritableComparable;
      
      import java.io.IOException;
      import java.util.HashSet;
      import java.util.Iterator;
      import java.util.List;
      import java.util.Set;
      
      @SuppressWarnings("rawtypes")
      public class BaseLoadingVertexResolver<I extends WritableComparable, V extends Writable, E extends Writable, M extends Writable>
              extends LoadingVertexResolver<I, V, E, M> {
        @Override
        public Vertex<I, V, E, M> resolve(I vertexId, VertexChanges<I, V, E, M> vertexChanges) throws IOException {
      
          Vertex<I, V, E, M> vertex = addVertexIfDesired(vertexId, vertexChanges);
      
          if (vertex != null) {
            addEdges(vertex, vertexChanges);
          } else {
            System.err.println("Ignore all addEdgeRequests for vertex#" + vertexId);
          }
          return vertex;
        }
      
        protected Vertex<I, V, E, M> addVertexIfDesired(
                I vertexId,
                VertexChanges<I, V, E, M> vertexChanges) {
          Vertex<I, V, E, M> vertex = null;
          if (hasVertexAdditions(vertexChanges)) {
            vertex = vertexChanges.getAddedVertexList().get(0);
          }
      
          return vertex;
        }
      
        protected void addEdges(Vertex<I, V, E, M> vertex,
                                VertexChanges<I, V, E, M> vertexChanges) throws IOException {
          Set<I> destVertexId = new HashSet<I>();
          if (vertex.hasEdges()) {
            List<Edge<I, E>> edgeList = vertex.getEdges();
            for (Iterator<Edge<I, E>> edges = edgeList.iterator(); edges.hasNext(); ) {
              Edge<I, E> edge = edges.next();
              if (destVertexId.contains(edge.getDestVertexId())) {
                edges.remove();
              } else {
                destVertexId.add(edge.getDestVertexId());
              }
            }
          }
      
          for (Vertex<I, V, E, M> vertex1 : vertexChanges.getAddedVertexList()) {
            if (vertex1.hasEdges()) {
              List<Edge<I, E>> edgeList = vertex1.getEdges();
              for (Edge<I, E> edge : edgeList) {
                if (destVertexId.contains(edge.getDestVertexId())) continue;
                destVertexId.add(edge.getDestVertexId());
                vertex.addEdge(edge.getDestVertexId(), edge.getValue());
              }
            }
          }
        }
      
        protected boolean hasVertexAdditions(VertexChanges<I, V, E, M> changes) {
          return changes != null && changes.getAddedVertexList() != null
                  && !changes.getAddedVertexList().isEmpty();
        }
      }

      代码说明:

      • 15行:定义BaseLoadingVertexResolver。用于处理有向图数据在加载过程中所遇到的冲突。

      • 18行:resolve为处理冲突的具体方法。例如当前的某一顶点进行了两次添加的过程(即进行了两次addVertexRequest操作),这种行为便称为冲突加载,需要人为处理冲突之后再参与计算。

    • 定义类SSSP

      import java.io.IOException;
      
      import com.aliyun.odps.graph.Combiner;
      import com.aliyun.odps.graph.ComputeContext;
      import com.aliyun.odps.graph.Edge;
      import com.aliyun.odps.graph.GraphJob;
      import com.aliyun.odps.graph.GraphLoader;
      import com.aliyun.odps.graph.MutationContext;
      import com.aliyun.odps.graph.Vertex;
      import com.aliyun.odps.graph.WorkerContext;
      import com.aliyun.odps.io.WritableRecord;
      import com.aliyun.odps.io.LongWritable;
      import com.aliyun.odps.data.TableInfo;
      
      public class SSSP {
        public static final String START_VERTEX = "sssp.start.vertex.id";
      
        public static class SSSPVertex extends
                Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {
          private static long startVertexId = -1;
      
          public SSSPVertex() {
            this.setValue(new LongWritable(Long.MAX_VALUE));
          }
      
          public boolean isStartVertex(
                  ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context) {
            if (startVertexId == -1) {
              String s = context.getConfiguration().get(START_VERTEX);
              startVertexId = Long.parseLong(s);
            }
            return getId().get() == startVertexId;
          }
      
          @Override
          public void compute(
                  ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
                  Iterable<LongWritable> messages) throws IOException {
            long minDist = isStartVertex(context) ? 0 : Long.MAX_VALUE;
            for (LongWritable msg : messages) {
              if (msg.get() < minDist) {
                minDist = msg.get();
              }
            }
            if (minDist < this.getValue().get()) {
              this.setValue(new LongWritable(minDist));
              if (hasEdges()) {
                for (Edge<LongWritable, LongWritable> e : this.getEdges()) {
                  context.sendMessage(e.getDestVertexId(), new LongWritable(minDist + e.getValue().get()));
                }
              }
            } else {
              voteToHalt();
            }
          }
      
          @Override
          public void cleanup(
                  WorkerContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
                  throws IOException {
            context.write(getId(), getValue());
          }
      
          @Override
          public String toString() {
            return "Vertex(id=" + this.getId() + ",value=" + this.getValue() + ",#edges=" + this.getEdges() + ")";
          }
        }
      
        public static class SSSPGraphLoader extends
                GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {
          @Override
          public void load(
                  LongWritable recordNum,
                  WritableRecord record,
                  MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
                  throws IOException {
            SSSPVertex vertex = new SSSPVertex();
            vertex.setId((LongWritable) record.get(0));
            String[] edges = record.get(1).toString().split(",");
            for (String edge : edges) {
              String[] ss = edge.split(":");
              vertex.addEdge(new LongWritable(Long.parseLong(ss[0])), new LongWritable(Long.parseLong(ss[1])));
            }
            context.addVertexRequest(vertex);
          }
        }
      
        public static class MinLongCombiner extends
                Combiner<LongWritable, LongWritable> {
          @Override
          public void combine(LongWritable vertexId, LongWritable combinedMessage,
                              LongWritable messageToCombine) throws IOException {
            if (combinedMessage.get() > messageToCombine.get()) {
              combinedMessage.set(messageToCombine.get());
            }
          }
        }
      
        public static void main(String[] args) throws IOException {
          if (args.length < 3) {
            System.out.println("Usage: <startnode> <input> <output>");
            System.exit(-1);
          }
          GraphJob job = new GraphJob();
          job.setGraphLoaderClass(SSSPGraphLoader.class);
          job.setVertexClass(SSSPVertex.class);
          job.setCombinerClass(MinLongCombiner.class);
          job.setLoadingVertexResolver(BaseLoadingVertexResolver.class);
          job.set(START_VERTEX, args[0]);
          job.addInput(TableInfo.builder().tableName(args[1]).build());
          job.addOutput(TableInfo.builder().tableName(args[2]).build());
          long startTime = System.currentTimeMillis();
          job.run();
          System.out.println("Job Finished in "
                  + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
        }
      }
      
                                  

      代码说明:

      • 19行:定义SSSPVertex。其中:

        • 点值表示该顶点到源点startVertexId的最短距离。

        • compute()方法使用迭代公式d[v]=min(d[v], d[u]+weight(u, v))计算最短距离值并更新至当前点值。

        • cleanup()方法将当前顶点到源点的最短距离写入结果表中。

      • 54行:当前顶点的Value值(即该顶点到源点的最短路径)未发生变化时,即调用voteToHalt()通过框架使该顶点进入halt状态。当所有顶点都进入halt状态时,计算结束。

      • 71行:定义GraphLoader图数据以有向图的方式加载图数据。通过将表内存储的记录解析为图的顶点或边信息加载至框架内。如上示例代码中,用户可通过addVertexRequest方式将图的顶点信息加载至图计算的上下文内。

      • 90行:定义MinLongCombiner。对发送给同一个点的消息进行合并,优化性能,减少内存占用。

      • 101行:主程序main函数中定义GraphJob。指定VertexGraphLoaderBaseLoadingVertexResolverCombiner等的实现,指定输入输出表。

      • 110行:添加处理冲突的类BaseLoadingVertexResolver

  • 无向图

    import com.aliyun.odps.data.TableInfo;
    import com.aliyun.odps.graph.*;
    import com.aliyun.odps.io.DoubleWritable;
    import com.aliyun.odps.io.LongWritable;
    import com.aliyun.odps.io.WritableRecord;
    
    import java.io.IOException;
    import java.util.HashSet;
    import java.util.Set;
    
    public class SSSPBenchmark4 {
        public static final String START_VERTEX = "sssp.start.vertex.id";
    
        public static class SSSPVertex extends
                Vertex<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
            private static long startVertexId = -1;
            public SSSPVertex() {
                this.setValue(new DoubleWritable(Double.MAX_VALUE));
            }
            public boolean isStartVertex(
                    ComputeContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context) {
                if (startVertexId == -1) {
                    String s = context.getConfiguration().get(START_VERTEX);
                    startVertexId = Long.parseLong(s);
                }
                return getId().get() == startVertexId;
            }
    
            @Override
            public void compute(
                    ComputeContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context,
                    Iterable<DoubleWritable> messages) throws IOException {
                double minDist = isStartVertex(context) ? 0 : Double.MAX_VALUE;
                for (DoubleWritable msg : messages) {
                    if (msg.get() < minDist) {
                        minDist = msg.get();
                    }
                }
                if (minDist < this.getValue().get()) {
                    this.setValue(new DoubleWritable(minDist));
                    if (hasEdges()) {
                        for (Edge<LongWritable, DoubleWritable> e : this.getEdges()) {
                            context.sendMessage(e.getDestVertexId(), new DoubleWritable(minDist
                                    + e.getValue().get()));
                        }
                    }
                } else {
                    voteToHalt();
                }
            }
    
            @Override
            public void cleanup(
                    WorkerContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context)
                    throws IOException {
                context.write(getId(), getValue());
            }
        }
    
        public static class MinLongCombiner extends
                Combiner<LongWritable, DoubleWritable> {
            @Override
            public void combine(LongWritable vertexId, DoubleWritable combinedMessage,
                                DoubleWritable messageToCombine) {
                if (combinedMessage.get() > messageToCombine.get()) {
                    combinedMessage.set(messageToCombine.get());
                }
            }
        }
    
        public static class SSSPGraphLoader extends
                GraphLoader<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
            @Override
            public void load(
                    LongWritable recordNum,
                    WritableRecord record,
                    MutationContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context)
                    throws IOException {
                LongWritable sourceVertexID = (LongWritable) record.get(0);
                LongWritable destinationVertexID = (LongWritable) record.get(1);
                DoubleWritable edgeValue = (DoubleWritable) record.get(2);
                Edge<LongWritable, DoubleWritable> edge = new Edge<LongWritable, DoubleWritable>(destinationVertexID, edgeValue);
                context.addEdgeRequest(sourceVertexID, edge);
                Edge<LongWritable, DoubleWritable> edge2 = new
                        Edge<LongWritable, DoubleWritable>(sourceVertexID, edgeValue);
                context.addEdgeRequest(destinationVertexID, edge2);
            }
        }
    
        public static class SSSPLoadingVertexResolver extends
                LoadingVertexResolver<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
    
            @Override
            public Vertex<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> resolve(
                    LongWritable vertexId,
                    VertexChanges<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> vertexChanges) throws IOException {
    
                SSSPVertex computeVertex = new SSSPVertex();
                computeVertex.setId(vertexId);
                Set<LongWritable> destinationVertexIDSet = new HashSet<>();
    
                if (hasEdgeAdditions(vertexChanges)) {
                    for (Edge<LongWritable, DoubleWritable> edge : vertexChanges.getAddedEdgeList()) {
                        if (!destinationVertexIDSet.contains(edge.getDestVertexId())) {
                            destinationVertexIDSet.add(edge.getDestVertexId());
                            computeVertex.addEdge(edge.getDestVertexId(), edge.getValue());
                        }
    
                    }
                }
    
                return computeVertex;
            }
    
            protected boolean hasEdgeAdditions(VertexChanges<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> changes) {
                return changes != null && changes.getAddedEdgeList() != null
                        && !changes.getAddedEdgeList().isEmpty();
            }
        }
    
        public static void main(String[] args) throws IOException {
            if (args.length < 2) {
                System.out.println("Usage: <startnode> <input> <output>");
                System.exit(-1);
            }
            GraphJob job = new GraphJob();
            job.setGraphLoaderClass(SSSPGraphLoader.class);
            job.setLoadingVertexResolver(SSSPLoadingVertexResolver.class);
            job.setVertexClass(SSSPVertex.class);
            job.setCombinerClass(MinLongCombiner.class);
            job.set(START_VERTEX, args[0]);
            job.addInput(TableInfo.builder().tableName(args[1]).build());
            job.addOutput(TableInfo.builder().tableName(args[2]).build());
            long startTime = System.currentTimeMillis();
            job.run();
            System.out.println("Job Finished in "
                    + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
        }
    }
                        

    代码说明:

    • 15行:定义SSSPVertex。其中:

      • 点值表示该顶点到源点startVertexId的最短距离。

      • compute()方法使用迭代公式d[v]=min(d[v], d[u]+weight(u, v))计算最短距离值并更新至当前点值。

      • cleanup()方法将当前顶点到源点的最短距离写入结果表中。

    • 54行:当前顶点的Value值(即该顶点到源点的最短路径)未发生变化时,即调用voteToHalt()通过框架使该顶点进入halt状态。当所有顶点都进入halt状态时,计算结束。

    • 61行:定义MinLongCombiner。对发送给同一个点的消息进行合并,优化性能,减少内存占用。

    • 72行:定义GraphLoader图数据以无向图的方式加载图数据。通过addEdgeRequest以两点之间的边作为无向边加载边信息,这样便可保证当前表内存储的图信息加载成无向图。

      • 80行:第一列表示初始点的ID。

      • 81行:第二列表示终点的ID。

      • 82行:第三列表示边的权重。

      • 83行:创建边,由终点ID和边的权重组成。

      • 84行:请求给初始点添加边。

      • 85行 - 第87行:每条Record表示双向边,重复第83行与第84行。

    • 定义SSSPLoadingVertexResolver。用于处理无向图数据在加载过程中所遇到的冲突。例如当前的某一边进行了两次添加的过程(即进行了两次addEdgeRequest操作),这种行为便称为冲突加载,需要人为处理重复添加的边才可保证计算正确性。

    • 101行:主程序main函数中定义GraphJob。指定VertexGraphLoaderSSSPLoadingVertexResolverCombiner等的实现,指定输入输出表。

运行结果

以下是基于有向图的代码示例的运行结果。操作详情,请参见编写Graph

vertex    value
1        0
2        2
3        1
4        3
5        2
  • vertex:代表当前顶点。

  • value:代表当前vertex到达源点(1)的最短距离。

说明

无向图数据,用户可以参考无向图代码示例中的初始点ID,终点ID,边的权值自行构造。

示例教程

关于上述示例代码的使用详情,请参见开发Graph