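Single-source shortest path (SSSP) is the problem of computing, for a given source vertex in a graph, the shortest distance from that source to every other vertex. Dijkstra's algorithm is the classic algorithm for solving SSSP on a directed graph.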
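Algorithm principle
Dijkstra's algorithm updates shortest-distance values vertex by vertex. Each vertex maintains its current shortest distance to the source. When this value changes, the vertex adds the weight of each outgoing edge to the new value and sends the result as a message to the corresponding adjacent vertex. In the next iteration, each adjacent vertex updates its current shortest distance based on the messages it received. The iteration ends when the current shortest distance of every vertex stops changing.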
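Initialization: the distance from the source s to itself is 0 (d[s]=0), and the distance from every other vertex u to s is infinity (d[u]=∞).
Iteration: if there is an edge from u to v, the shortest distance from s to v is updated to d[v]=min(d[v], d[u]+weight(u, v)). The iteration ends when the distance from every vertex to s no longer changes.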
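The initialization and iteration rule above can be sketched on a single machine as follows. This is only an illustration under stated assumptions: it is plain Java rather than MaxCompute Graph code, the class name RelaxationSketch and the sample edges are hypothetical, and edge weights are assumed to be non-negative so that the loop terminates.

```java
import java.util.Arrays;

// Single-machine sketch of d[v] = min(d[v], d[u] + weight(u, v)).
// Hypothetical names; not part of the MaxCompute Graph SDK.
class RelaxationSketch {
    static final long INF = Long.MAX_VALUE;

    // edges[i] = {u, v, weight}; vertices are numbered 0..n-1.
    static long[] shortestDistances(int n, long[][] edges, int source) {
        long[] d = new long[n];
        Arrays.fill(d, INF);   // d[u] = infinity for every other vertex
        d[source] = 0;         // d[s] = 0
        boolean changed = true;
        while (changed) {      // stop once no distance changes in a full pass
            changed = false;
            for (long[] e : edges) {
                int u = (int) e[0];
                int v = (int) e[1];
                // Skip unreachable u so that INF + weight cannot overflow.
                if (d[u] != INF && d[u] + e[2] < d[v]) {
                    d[v] = d[u] + e[2];
                    changed = true;
                }
            }
        }
        return d;
    }

    public static void main(String[] args) {
        long[][] edges = {{0, 1, 4}, {0, 2, 1}, {2, 1, 2}, {1, 3, 5}};
        // Prints [0, 3, 1, 8]
        System.out.println(Arrays.toString(shortestDistances(4, edges, 0)));
    }
}
```

In the sample, vertex 1 is reached more cheaply through vertex 2 (1 + 2 = 3) than directly (4), which is exactly the case the min() in the formula handles.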
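For a weighted directed graph G=(V,E), there can be many paths from a source s to a sink v. The path with the smallest sum of edge weights is the shortest path, and that sum is the shortest distance from s to v.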
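As the principle above shows, each vertex only needs its own value and the messages it receives, so this algorithm is well suited to being solved with a MaxCompute Graph program.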
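Use cases
Graphs are usually classified as directed or undirected, and MaxCompute supports both. Depending on how the source data is distributed, the reachable paths differ between a directed graph and an undirected graph built from the same data, so solving SSSP on them yields different results. MaxCompute Graph uses the directed graph as its underlying data model, and the framework always computes on a directed-graph model.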
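The following sketch illustrates why the same edge data can give different SSSP results depending on whether it is treated as directed or undirected. It is plain Java, not MaxCompute Graph code; the class DirectionSketch and the three sample edges are hypothetical, and weights are assumed to be non-negative.

```java
import java.util.Arrays;

// Same edge list, interpreted as directed or as undirected.
// Hypothetical names; not part of the MaxCompute Graph SDK.
class DirectionSketch {
    static final long INF = Long.MAX_VALUE;

    // edges[i] = {u, v, weight}; vertices are numbered 0..n-1.
    static long[] distances(int n, long[][] edges, int source, boolean undirected) {
        long[] d = new long[n];
        Arrays.fill(d, INF);
        d[source] = 0;
        boolean changed = true;
        while (changed) {
            changed = false;
            for (long[] e : edges) {
                changed |= relax(d, (int) e[0], (int) e[1], e[2]);
                if (undirected) {
                    // An undirected edge can also be traversed from v to u.
                    changed |= relax(d, (int) e[1], (int) e[0], e[2]);
                }
            }
        }
        return d;
    }

    private static boolean relax(long[] d, int u, int v, long w) {
        if (d[u] != INF && d[u] + w < d[v]) {
            d[v] = d[u] + w;
            return true;
        }
        return false;
    }

    public static void main(String[] args) {
        long[][] edges = {{0, 1, 5}, {2, 0, 1}, {2, 1, 1}};
        // Directed: vertex 2 is unreachable from 0, and vertex 1 costs 5.
        System.out.println(Arrays.toString(distances(3, edges, 0, false)));
        // Undirected: 0-2-1 costs 1 + 1 = 2, which beats the direct edge.
        System.out.println(Arrays.toString(distances(3, edges, 0, true)));
    }
}
```

In the directed run, vertex 2 keeps the value Long.MAX_VALUE, which mirrors how the sample vertex classes in this topic represent an unreachable vertex.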
Code examples
The following code examples cover different scenarios.
Directed graph
Define the BaseLoadingVertexResolver class. This conflict-resolution class is referenced by the SSSP class.
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.LoadingVertexResolver;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.VertexChanges;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableComparable;

import java.io.IOException;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

@SuppressWarnings("rawtypes")
public class BaseLoadingVertexResolver<I extends WritableComparable, V extends Writable, E extends Writable, M extends Writable>
    extends LoadingVertexResolver<I, V, E, M> {

  @Override
  public Vertex<I, V, E, M> resolve(I vertexId, VertexChanges<I, V, E, M> vertexChanges) throws IOException {
    Vertex<I, V, E, M> vertex = addVertexIfDesired(vertexId, vertexChanges);
    if (vertex != null) {
      addEdges(vertex, vertexChanges);
    } else {
      System.err.println("Ignore all addEdgeRequests for vertex#" + vertexId);
    }
    return vertex;
  }

  // Keeps the first added vertex, if any, as the surviving vertex.
  protected Vertex<I, V, E, M> addVertexIfDesired(
      I vertexId, VertexChanges<I, V, E, M> vertexChanges) {
    Vertex<I, V, E, M> vertex = null;
    if (hasVertexAdditions(vertexChanges)) {
      vertex = vertexChanges.getAddedVertexList().get(0);
    }
    return vertex;
  }

  protected void addEdges(Vertex<I, V, E, M> vertex, VertexChanges<I, V, E, M> vertexChanges)
      throws IOException {
    Set<I> destVertexId = new HashSet<I>();
    // Remove edges of the surviving vertex that point to an already-seen destination.
    if (vertex.hasEdges()) {
      List<Edge<I, E>> edgeList = vertex.getEdges();
      for (Iterator<Edge<I, E>> edges = edgeList.iterator(); edges.hasNext(); ) {
        Edge<I, E> edge = edges.next();
        if (destVertexId.contains(edge.getDestVertexId())) {
          edges.remove();
        } else {
          destVertexId.add(edge.getDestVertexId());
        }
      }
    }
    // Merge edges from every added vertex, skipping destinations that are already present.
    for (Vertex<I, V, E, M> vertex1 : vertexChanges.getAddedVertexList()) {
      if (vertex1.hasEdges()) {
        List<Edge<I, E>> edgeList = vertex1.getEdges();
        for (Edge<I, E> edge : edgeList) {
          if (destVertexId.contains(edge.getDestVertexId())) {
            continue;
          }
          destVertexId.add(edge.getDestVertexId());
          vertex.addEdge(edge.getDestVertexId(), edge.getValue());
        }
      }
    }
  }

  protected boolean hasVertexAdditions(VertexChanges<I, V, E, M> changes) {
    return changes != null && changes.getAddedVertexList() != null
        && !changes.getAddedVertexList().isEmpty();
  }
}
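Code description:
BaseLoadingVertexResolver: resolves the conflicts that occur while directed-graph data is being loaded.
resolve(): the method that actually resolves a conflict. For example, if the same vertex is added twice (that is, addVertexRequest is called twice for it), this is a loading conflict, and it must be resolved explicitly before the vertex takes part in the computation.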
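The merge rule applied by BaseLoadingVertexResolver can be sketched without the SDK as follows. This is a simplified model under stated assumptions: each addVertexRequest for the same vertex ID is represented as a list of {destination ID, weight} pairs, the first edge seen for a destination is kept, and the class name ResolveSketch and the sample data are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Models how duplicate additions of one vertex are merged into a single edge set.
// Hypothetical names; not part of the MaxCompute Graph SDK.
class ResolveSketch {
    // Each inner list holds the {destId, weight} edges carried by one added vertex.
    static Map<Long, Long> mergeEdges(List<List<long[]>> addedVertices) {
        Map<Long, Long> merged = new LinkedHashMap<>();
        for (List<long[]> edges : addedVertices) {
            for (long[] e : edges) {
                // The first edge to a given destination wins; later duplicates are dropped.
                merged.putIfAbsent(e[0], e[1]);
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        List<List<long[]>> added = new ArrayList<>();
        added.add(Arrays.asList(new long[]{2, 2}, new long[]{3, 1})); // first request
        added.add(Arrays.asList(new long[]{3, 9}, new long[]{4, 4})); // second request
        // Prints {2=2, 3=1, 4=4}: the duplicate edge to vertex 3 (weight 9) is ignored.
        System.out.println(mergeEdges(added));
    }
}
```

Keeping the first edge is a design choice of the sample resolver, not a requirement of SSSP; a resolver that kept the smallest weight per destination would be an equally valid policy for shortest-distance computation.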
Define the SSSP class.
import java.io.IOException;

import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.data.TableInfo;

public class SSSP {

  public static final String START_VERTEX = "sssp.start.vertex.id";

  public static class SSSPVertex extends
      Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {

    private static long startVertexId = -1;

    public SSSPVertex() {
      // Every vertex starts at "infinity".
      this.setValue(new LongWritable(Long.MAX_VALUE));
    }

    public boolean isStartVertex(
        ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context) {
      if (startVertexId == -1) {
        String s = context.getConfiguration().get(START_VERTEX);
        startVertexId = Long.parseLong(s);
      }
      return getId().get() == startVertexId;
    }

    @Override
    public void compute(
        ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
        Iterable<LongWritable> messages) throws IOException {
      long minDist = isStartVertex(context) ? 0 : Long.MAX_VALUE;
      for (LongWritable msg : messages) {
        if (msg.get() < minDist) {
          minDist = msg.get();
        }
      }

      if (minDist < this.getValue().get()) {
        // The distance improved: store it and notify all adjacent vertices.
        this.setValue(new LongWritable(minDist));
        if (hasEdges()) {
          for (Edge<LongWritable, LongWritable> e : this.getEdges()) {
            context.sendMessage(e.getDestVertexId(),
                new LongWritable(minDist + e.getValue().get()));
          }
        }
      } else {
        voteToHalt();
      }
    }

    @Override
    public void cleanup(
        WorkerContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
        throws IOException {
      context.write(getId(), getValue());
    }

    @Override
    public String toString() {
      return "Vertex(id=" + this.getId() + ",value=" + this.getValue()
          + ",#edges=" + this.getEdges() + ")";
    }
  }

  public static class SSSPGraphLoader extends
      GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {

    @Override
    public void load(
        LongWritable recordNum,
        WritableRecord record,
        MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
        throws IOException {
      SSSPVertex vertex = new SSSPVertex();
      vertex.setId((LongWritable) record.get(0));
      // The second column holds comma-separated "destination:weight" pairs.
      String[] edges = record.get(1).toString().split(",");
      for (String edge : edges) {
        String[] ss = edge.split(":");
        vertex.addEdge(new LongWritable(Long.parseLong(ss[0])),
            new LongWritable(Long.parseLong(ss[1])));
      }
      context.addVertexRequest(vertex);
    }
  }

  public static class MinLongCombiner extends
      Combiner<LongWritable, LongWritable> {

    @Override
    public void combine(LongWritable vertexId, LongWritable combinedMessage,
        LongWritable messageToCombine) throws IOException {
      if (combinedMessage.get() > messageToCombine.get()) {
        combinedMessage.set(messageToCombine.get());
      }
    }
  }

  public static void main(String[] args) throws IOException {
    if (args.length < 3) {
      System.out.println("Usage: <startnode> <input> <output>");
      System.exit(-1);
    }

    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(SSSPGraphLoader.class);
    job.setVertexClass(SSSPVertex.class);
    job.setCombinerClass(MinLongCombiner.class);
    job.setLoadingVertexResolver(BaseLoadingVertexResolver.class);

    job.set(START_VERTEX, args[0]);
    job.addInput(TableInfo.builder().tableName(args[1]).build());
    job.addOutput(TableInfo.builder().tableName(args[2]).build());

    long startTime = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
  }
}
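Code description:
SSSPVertex: defines the vertex class. In this class:
  The vertex value is the shortest distance from this vertex to the source vertex startVertexId.
  The compute() method applies the iteration formula d[v]=min(d[v], d[u]+weight(u, v)) to compute the shortest distance and stores it as the current vertex value. The cleanup() method writes the shortest distance from the current vertex to the source into the result table.
voteToHalt(): when the value of the current vertex (its shortest distance to the source) does not change, compute() calls voteToHalt() so that the framework puts the vertex into the halt state. The computation ends when all vertices are in the halt state.
SSSPGraphLoader: defines the GraphLoader, which loads the data as a directed graph. It parses the records stored in the table into vertex and edge information and loads that information into the framework. In the sample code above, addVertexRequest loads the vertex information into the graph computation context.
MinLongCombiner: merges the messages sent to the same vertex, which improves performance and reduces memory usage.
main(): defines the GraphJob, specifies the implementations of Vertex, GraphLoader, BaseLoadingVertexResolver, and Combiner, and specifies the input and output tables.
job.setLoadingVertexResolver(BaseLoadingVertexResolver.class): registers the conflict-resolution class BaseLoadingVertexResolver.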
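To make the interplay of compute(), the combiner, and voteToHalt() easier to follow, here is a minimal single-machine simulation of the superstep loop. It is a sketch under stated assumptions, not the framework's actual scheduling: the class SuperstepSketch, the map-based graph, and the sample data are hypothetical, messages are merged with min as MinLongCombiner does, and a vertex whose value does not improve simply does nothing, which stands in for voting to halt.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

// Simulates synchronous supersteps with a min-combiner.
// Hypothetical names; not part of the MaxCompute Graph SDK.
class SuperstepSketch {
    // graph: vertex ID -> (destination ID -> edge weight)
    static Map<Long, Long> run(Map<Long, Map<Long, Long>> graph, long start) {
        Map<Long, Long> value = new HashMap<>();
        for (Long id : graph.keySet()) {
            value.put(id, Long.MAX_VALUE);
        }
        Map<Long, Long> inbox = new HashMap<>();
        boolean firstStep = true;
        // Run until no messages are in flight, i.e. every vertex has halted.
        while (firstStep || !inbox.isEmpty()) {
            Map<Long, Long> outbox = new HashMap<>();
            for (Map.Entry<Long, Map<Long, Long>> v : graph.entrySet()) {
                long id = v.getKey();
                long minDist = (id == start) ? 0 : Long.MAX_VALUE;
                Long msg = inbox.get(id); // already combined to one value
                if (msg != null && msg < minDist) {
                    minDist = msg;
                }
                if (minDist < value.get(id)) {
                    value.put(id, minDist);
                    for (Map.Entry<Long, Long> e : v.getValue().entrySet()) {
                        // Combiner: keep only the smallest message per destination.
                        outbox.merge(e.getKey(), minDist + e.getValue(), Math::min);
                    }
                }
                // Otherwise the value is unchanged: the vertex "votes to halt".
            }
            inbox = outbox;
            firstStep = false;
        }
        return value;
    }

    public static void main(String[] args) {
        Map<Long, Map<Long, Long>> graph = new HashMap<>();
        graph.put(1L, new HashMap<>());
        graph.put(2L, new HashMap<>());
        graph.put(3L, new HashMap<>());
        graph.put(4L, new HashMap<>());
        graph.get(1L).put(2L, 5L);
        graph.get(1L).put(3L, 1L);
        graph.get(1L).put(4L, 1L);
        graph.get(3L).put(2L, 3L);
        graph.get(4L).put(2L, 1L);
        // Prints {1=0, 2=2, 3=1, 4=1}
        System.out.println(new TreeMap<>(run(graph, 1L)));
    }
}
```

In the second superstep of the sample, vertices 3 and 4 both send a message to vertex 2 (4 and 2 respectively); the combiner forwards only the smaller one, so vertex 2 receives a single message instead of two.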
Undirected graph
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.*;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.WritableRecord;

import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

public class SSSPBenchmark4 {

  public static final String START_VERTEX = "sssp.start.vertex.id";

  public static class SSSPVertex extends
      Vertex<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {

    private static long startVertexId = -1;

    public SSSPVertex() {
      // Every vertex starts at "infinity".
      this.setValue(new DoubleWritable(Double.MAX_VALUE));
    }

    public boolean isStartVertex(
        ComputeContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context) {
      if (startVertexId == -1) {
        String s = context.getConfiguration().get(START_VERTEX);
        startVertexId = Long.parseLong(s);
      }
      return getId().get() == startVertexId;
    }

    @Override
    public void compute(
        ComputeContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context,
        Iterable<DoubleWritable> messages) throws IOException {
      double minDist = isStartVertex(context) ? 0 : Double.MAX_VALUE;
      for (DoubleWritable msg : messages) {
        if (msg.get() < minDist) {
          minDist = msg.get();
        }
      }

      if (minDist < this.getValue().get()) {
        // The distance improved: store it and notify all adjacent vertices.
        this.setValue(new DoubleWritable(minDist));
        if (hasEdges()) {
          for (Edge<LongWritable, DoubleWritable> e : this.getEdges()) {
            context.sendMessage(e.getDestVertexId(),
                new DoubleWritable(minDist + e.getValue().get()));
          }
        }
      } else {
        voteToHalt();
      }
    }

    @Override
    public void cleanup(
        WorkerContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context)
        throws IOException {
      context.write(getId(), getValue());
    }
  }

  public static class MinLongCombiner extends
      Combiner<LongWritable, DoubleWritable> {

    @Override
    public void combine(LongWritable vertexId, DoubleWritable combinedMessage,
        DoubleWritable messageToCombine) {
      if (combinedMessage.get() > messageToCombine.get()) {
        combinedMessage.set(messageToCombine.get());
      }
    }
  }

  public static class SSSPGraphLoader extends
      GraphLoader<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {

    @Override
    public void load(
        LongWritable recordNum,
        WritableRecord record,
        MutationContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context)
        throws IOException {
      LongWritable sourceVertexID = (LongWritable) record.get(0);
      LongWritable destinationVertexID = (LongWritable) record.get(1);
      DoubleWritable edgeValue = (DoubleWritable) record.get(2);
      Edge<LongWritable, DoubleWritable> edge =
          new Edge<LongWritable, DoubleWritable>(destinationVertexID, edgeValue);
      context.addEdgeRequest(sourceVertexID, edge);
      // Add the reverse edge so that the record behaves as an undirected edge.
      Edge<LongWritable, DoubleWritable> edge2 =
          new Edge<LongWritable, DoubleWritable>(sourceVertexID, edgeValue);
      context.addEdgeRequest(destinationVertexID, edge2);
    }
  }

  public static class SSSPLoadingVertexResolver extends
      LoadingVertexResolver<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {

    @Override
    public Vertex<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> resolve(
        LongWritable vertexId,
        VertexChanges<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> vertexChanges)
        throws IOException {
      SSSPVertex computeVertex = new SSSPVertex();
      computeVertex.setId(vertexId);
      Set<LongWritable> destinationVertexIDSet = new HashSet<>();

      if (hasEdgeAdditions(vertexChanges)) {
        for (Edge<LongWritable, DoubleWritable> edge : vertexChanges.getAddedEdgeList()) {
          // Keep only the first edge added for each destination vertex.
          if (!destinationVertexIDSet.contains(edge.getDestVertexId())) {
            destinationVertexIDSet.add(edge.getDestVertexId());
            computeVertex.addEdge(edge.getDestVertexId(), edge.getValue());
          }
        }
      }

      return computeVertex;
    }

    protected boolean hasEdgeAdditions(
        VertexChanges<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> changes) {
      return changes != null && changes.getAddedEdgeList() != null
          && !changes.getAddedEdgeList().isEmpty();
    }
  }

  public static void main(String[] args) throws IOException {
    // Three arguments are read below, so require all three.
    if (args.length < 3) {
      System.out.println("Usage: <startnode> <input> <output>");
      System.exit(-1);
    }

    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(SSSPGraphLoader.class);
    job.setLoadingVertexResolver(SSSPLoadingVertexResolver.class);
    job.setVertexClass(SSSPVertex.class);
    job.setCombinerClass(MinLongCombiner.class);

    job.set(START_VERTEX, args[0]);
    job.addInput(TableInfo.builder().tableName(args[1]).build());
    job.addOutput(TableInfo.builder().tableName(args[2]).build());

    long startTime = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
  }
}
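Code description:
SSSPVertex: defines the vertex class. In this class:
  The vertex value is the shortest distance from this vertex to the source vertex startVertexId.
  The compute() method applies the iteration formula d[v]=min(d[v], d[u]+weight(u, v)) to compute the shortest distance and stores it as the current vertex value. The cleanup() method writes the shortest distance from the current vertex to the source into the result table.
voteToHalt(): when the value of the current vertex (its shortest distance to the source) does not change, compute() calls voteToHalt() so that the framework puts the vertex into the halt state. The computation ends when all vertices are in the halt state.
MinLongCombiner: merges the messages sent to the same vertex, which improves performance and reduces memory usage.
SSSPGraphLoader: defines the GraphLoader, which loads the data as an undirected graph. It calls addEdgeRequest to load the edge between two vertices in both directions, which ensures that the graph information stored in the table is loaded as an undirected graph. In load():
  sourceVertexID: the first column is the ID of the edge's start vertex.
  destinationVertexID: the second column is the ID of the edge's end vertex.
  edgeValue: the third column is the weight of the edge.
  edge: creates an edge that consists of the end vertex ID and the edge weight.
  context.addEdgeRequest(sourceVertexID, edge): requests that the edge be added to the start vertex.
  edge2 and context.addEdgeRequest(destinationVertexID, edge2): each record represents a bidirectional edge, so the previous two steps are repeated in the reverse direction.
SSSPLoadingVertexResolver: resolves the conflicts that occur while undirected-graph data is being loaded. For example, if the same edge is added twice (that is, addEdgeRequest is called twice for it), this is a loading conflict, and the duplicate edges must be handled explicitly to keep the computation correct.
main(): defines the GraphJob, specifies the implementations of Vertex, GraphLoader, SSSPLoadingVertexResolver, and Combiner, and specifies the input and output tables.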
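The loading behavior described above, where one record yields two edges and duplicate edges are dropped, can be sketched without the SDK as follows. This is only an illustration: the classes UndirectedLoadSketch and Rec and the sample records are hypothetical, and the sketch assumes edges arrive in record order, which the real framework may not guarantee.

```java
import java.util.Map;
import java.util.TreeMap;

// Models loading (source, destination, weight) records as an undirected graph.
// Hypothetical names; not part of the MaxCompute Graph SDK.
class UndirectedLoadSketch {
    static final class Rec {
        final long src;
        final long dst;
        final double weight;

        Rec(long src, long dst, double weight) {
            this.src = src;
            this.dst = dst;
            this.weight = weight;
        }
    }

    static Map<Long, Map<Long, Double>> load(Rec[] records) {
        Map<Long, Map<Long, Double>> adj = new TreeMap<>();
        for (Rec r : records) {
            // Forward edge; putIfAbsent drops a second edge to the same destination.
            adj.computeIfAbsent(r.src, k -> new TreeMap<>()).putIfAbsent(r.dst, r.weight);
            // Reverse edge, so the record behaves as an undirected edge.
            adj.computeIfAbsent(r.dst, k -> new TreeMap<>()).putIfAbsent(r.src, r.weight);
        }
        return adj;
    }

    public static void main(String[] args) {
        Rec[] records = {
            new Rec(1, 2, 2.0),
            new Rec(2, 1, 7.0), // duplicates the 1-2 edge in both directions
            new Rec(2, 3, 1.5)
        };
        // Prints {1={2=2.0}, 2={1=2.0, 3=1.5}, 3={2=1.5}}
        System.out.println(load(records));
    }
}
```

Because the first edge wins, the second record's weight of 7.0 never enters the graph. If your data can contain the same edge with different weights, decide explicitly which one should survive, since that choice changes the SSSP result.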
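Results
The following is the result of running the directed-graph code example. For details about how to run it, see the "Write a Graph program" (编写Graph) topic.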
vertex value
1 0
2 2
3 1
4 3
5 2
vertex: the current vertex.
value: the shortest distance from the current vertex to the source vertex (1).
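The input table that produced this result is not shown in this section. As a cross-check, the sketch below uses an assumed input: five rows in the "vertex ID" plus "destination:weight,..." layout that SSSPGraphLoader parses, chosen so that the computed distances from vertex 1 match the table above. The class ResultCheckSketch and the rows themselves are assumptions for illustration, and the code is plain Java rather than a MaxCompute Graph job.

```java
import java.util.Map;
import java.util.TreeMap;

// Parses rows in the loader's "dest:weight,dest:weight" layout and relaxes to a fixed point.
// Hypothetical names and assumed input; not part of the MaxCompute Graph SDK.
class ResultCheckSketch {
    // rows[i] = {vertexId, "dest:weight,dest:weight,..."}
    static Map<Long, Long> solve(String[][] rows, long start) {
        Map<Long, Map<Long, Long>> graph = new TreeMap<>();
        for (String[] row : rows) {
            Map<Long, Long> edges = new TreeMap<>();
            for (String edge : row[1].split(",")) {
                String[] ss = edge.split(":");
                edges.put(Long.parseLong(ss[0]), Long.parseLong(ss[1]));
            }
            graph.put(Long.parseLong(row[0]), edges);
        }

        Map<Long, Long> dist = new TreeMap<>();
        for (Long id : graph.keySet()) {
            dist.put(id, Long.MAX_VALUE);
        }
        dist.put(start, 0L);

        boolean changed = true;
        while (changed) {
            changed = false;
            for (Map.Entry<Long, Map<Long, Long>> u : graph.entrySet()) {
                long du = dist.get(u.getKey());
                if (du == Long.MAX_VALUE) {
                    continue; // not reached yet
                }
                for (Map.Entry<Long, Long> e : u.getValue().entrySet()) {
                    Long dv = dist.get(e.getKey());
                    if (dv == null || du + e.getValue() < dv) {
                        dist.put(e.getKey(), du + e.getValue());
                        changed = true;
                    }
                }
            }
        }
        return dist;
    }

    public static void main(String[] args) {
        String[][] rows = { // assumed input, not taken from this section
            {"1", "2:2,3:1,4:4"},
            {"2", "1:2,3:2,4:1"},
            {"3", "1:1,2:2,5:1"},
            {"4", "1:4,2:1,5:1"},
            {"5", "3:1,4:1"}
        };
        // Prints {1=0, 2=2, 3=1, 4=3, 5=2}
        System.out.println(solve(rows, 1L));
    }
}
```

With this input, vertex 4 is reached through vertex 2 (2 + 1 = 3) rather than by the direct edge of weight 4, which is why its value in the table is 3.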
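For undirected graph data, you can construct your own input by following the columns used in the undirected-graph code example: start vertex ID, end vertex ID, and edge weight.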