从零开始设计和实现一个 Python 下的 DAG(有向无环图)
我们一起来从零开始设计和实现一个 Python 下的 DAG(有向无环图),并结合 GitHub 上常见的代码模式进行优化。
第一步:理解 DAG 的基本概念和需求
首先,我们需要明确 DAG 的核心概念:
我们的目标是实现一个 Python 类,能够:
- 添加节点: 允许用户向 DAG 中添加任务节点。
- 添加边: 允许用户定义节点之间的依赖关系。
- 执行 DAG: 按照依赖关系执行节点代表的任务。
- 检测环: 在添加边时或执行前检测是否存在环。
第二步:初步设计 – 核心数据结构
在 Python 中,表示图最常用的方式是使用邻接表。对于 DAG,我们可以使用字典来实现邻接表,其中:
同时,为了方便反向查找依赖关系,我们也可以维护一个反向邻接表:
此外,为了存储节点代表的任务(函数或其他可执行对象),我们可以使用另一个字典:
初步代码框架:
python
代码解读
复制代码
class DAG: def __init__(self): self._graph = {} # 邻接表:节点 -> [后继节点] self._reverse_graph = {} # 反向邻接表:节点 -> [前驱节点] self._tasks = {} # 节点 -> 任务 def add_node(self, node, task=None): if node not in self._graph: self._graph[node] = [] self._reverse_graph[node] = [] self._tasks[node] = task def add_edge(self, from_node, to_node): if from_node not in self._graph or to_node not in self._graph: raise ValueError("节点不存在") if to_node not in self._graph[from_node]: self._graph[from_node].append(to_node) self._reverse_graph[to_node].append(from_node) # 思考:这里是否需要进行环检测?
第三步:完善添加边和环检测
在 add_edge
方法中,我们需要考虑环检测。一个常见的环检测方法是使用深度优先搜索(DFS)。
环检测思路:
- 维护三个集合:
visited
: 已经访问过的节点。visiting
: 当前正在访问的节点。- 从每个节点开始进行 DFS。
- 在 DFS 过程中,如果遇到一个节点已经在
visiting
集合中,则说明存在环。 - 当一个节点的所有后继节点都访问完毕后,将其从
visiting
集合中移除,并加入visited
集合。
改进后的 add_edge
方法:
python
代码解读
复制代码
class DAG: # ... (之前的代码) def add_edge(self, from_node, to_node): if from_node not in self._graph or to_node not in self._graph: raise ValueError("节点不存在") if to_node not in self._graph[from_node]: self._graph[from_node].append(to_node) self._reverse_graph[to_node].append(from_node) if self._has_cycle(): # 如果添加边导致环,则撤销操作 self._graph[from_node].remove(to_node) self._reverse_graph[to_node].remove(from_node) raise ValueError("添加边会导致环") def _has_cycle(self): visited = set() visiting = set() def _dfs(node): visiting.add(node) for neighbor in self._graph.get(node, []): if neighbor in visiting: return True if neighbor not in visited: if _dfs(neighbor): return True visiting.remove(node) visited.add(node) return False for node in self._graph: if node not in visited: if _dfs(node): return True return False
思考与优化 1:
第四步:实现 DAG 的执行
执行 DAG 的核心是按照依赖关系排序节点,这可以通过拓扑排序算法实现。
拓扑排序思路:
- 计算每个节点的入度(指向该节点的边的数量)。
- 将所有入度为 0 的节点放入一个队列。
- 当队列不为空时:
- 从队列中取出一个节点。
- 执行该节点对应的任务。
- 将该节点的所有后继节点的入度减 1。
- 如果某个后继节点的入度变为 0,则将其加入队列。
- 如果所有节点都被处理,则执行成功。否则,图中存在环(这应该在添加边时就被检测出来)。
实现 execute
方法:
python
代码解读
复制代码
from collections import deque class DAG: # ... (之前的代码) def execute(self): in_degree = {node: len(self._reverse_graph[node]) for node in self._graph} queue = deque([node for node in self._graph if in_degree[node] == 0]) executed_nodes = [] while queue: node = queue.popleft() print(f"执行节点: {node}") task = self._tasks.get(node) if task: task() # 执行任务 executed_nodes.append(node) for neighbor in self._graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) if len(executed_nodes) != len(self._graph): raise RuntimeError("图中存在环,无法完成拓扑排序") # 理论上不会发生,因为添加边时已检测
思考与优化 2:
threading
或 asyncio
模块来实现。第五步:添加更灵活的任务定义和执行
目前,我们假设任务是简单的无参函数。为了更灵活地处理各种任务,我们可以允许用户在添加节点时传递任意可调用对象,并允许在执行时传递参数。
改进后的 add_node
和 execute
方法:
python
代码解读
复制代码
class DAG: # ... (之前的代码) def add_node(self, node, task=None, *args, **kwargs): if node not in self._graph: self._graph[node] = [] self._reverse_graph[node] = [] self._tasks[node] = (task, args, kwargs) # 存储任务和参数 def execute(self): in_degree = {node: len(self._reverse_graph[node]) for node in self._graph} queue = deque([node for node in self._graph if in_degree[node] == 0]) executed_nodes = {} # 存储执行结果 while queue: node = queue.popleft() task_info = self._tasks.get(node) if task_info: task, args, kwargs = task_info print(f"执行节点: {node}, 任务: {task.__name__ if callable(task) else task}") try: result = task(*args, **kwargs) # 执行任务并获取结果 executed_nodes[node] = result except Exception as e: print(f"节点 {node} 执行失败: {e}") raise # 可以选择抛出异常或继续执行 for neighbor in self._graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) if len(executed_nodes) != len(self._graph): raise RuntimeError("图中存在环或部分节点未执行") return executed_nodes # 返回执行结果
思考与优化 3:
try-except
块,可以捕获异常并进行处理。可以根据需求选择抛出异常、记录日志或跳过该节点。第六步:借鉴 GitHub 常见代码模式进行优化
在 GitHub 上,常见的代码模式可以帮助我们提高代码的可读性、可维护性和性能。
示例:使用装饰器简化任务添加
python
代码解读
复制代码
class DAG: # ... (之前的代码) def task(self, node, *args, **kwargs): def decorator(func): self.add_node(node, func, *args, **kwargs) return func return decorator # 使用装饰器添加任务 dag = DAG() @dag.task("task_a") def task_a(): print("执行任务 A") return "result_a" @dag.task("task_b", input_value=10) def task_b(input_value): print(f"执行任务 B,输入: {input_value}") return input_value * 2 dag.add_edge("task_a", "task_b") dag.execute()
思考与优化 4:
graphviz
等库。asyncio
或 concurrent.futures
实现异步并行执行,提高性能。第七步:进一步的思考和扩展
总结与最终代码(包含一些优化):
python
代码解读
复制代码
from collections import deque from functools import wraps class DAG: def __init__(self): self._graph = {} self._reverse_graph = {} self._tasks = {} def add_node(self, node, task=None, *args, **kwargs): if node not in self._graph: self._graph[node] = [] self._reverse_graph[node] = [] self._tasks[node] = (task, args, kwargs) def add_edge(self, from_node, to_node): if from_node not in self._graph or to_node not in self._graph: raise ValueError("节点不存在") if to_node not in self._graph[from_node]: self._graph[from_node].append(to_node) self._reverse_graph[to_node].append(from_node) if self._has_cycle(): self._graph[from_node].remove(to_node) self._reverse_graph[to_node].remove(from_node) raise ValueError("添加边会导致环") def _has_cycle(self): visited = set() visiting = set() def _dfs(node): visiting.add(node) for neighbor in self._graph.get(node, []): if neighbor in visiting: return True if neighbor not in visited: if _dfs(neighbor): return True visiting.remove(node) visited.add(node) return False for node in self._graph: if node not in visited: if _dfs(node): return True return False def task(self, node, *args, **kwargs): def decorator(func): self.add_node(node, func, *args, **kwargs) @wraps(func) def wrapper(*_args, **_kwargs): return func(*_args, **_kwargs) return wrapper return decorator def execute(self): in_degree = {node: len(self._reverse_graph[node]) for node in self._graph} queue = deque([node for node in self._graph if in_degree[node] == 0]) executed_nodes = {} while queue: node = queue.popleft() task_info = self._tasks.get(node) if task_info: task, args, kwargs = task_info print(f"执行节点: {node}, 任务: {task.__name__ if callable(task) else task}") try: result = task(*args, **kwargs) executed_nodes[node] = result except Exception as e: print(f"节点 {node} 执行失败: {e}") raise for neighbor in self._graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) if len(executed_nodes) != len(self._graph): raise RuntimeError("图中存在环或部分节点未执行") return executed_nodes # 示例用法 dag = DAG() @dag.task("start") def start_task(): print("开始任务") return 10 @dag.task("process", multiplier=2) def process_task(value, multiplier): print(f"处理任务,值: {value},乘数: {multiplier}") return value * multiplier @dag.task("end") def end_task(value): print(f"结束任务,最终值: {value}") dag.add_edge("start", "process") dag.add_edge("process", "end") results = dag.execute() print("执行结果:", results)
通过这个过程,我们从零开始设计并实现了一个 Python 下的 DAG,并逐步进行了优化和扩展
作者:Java八股文面试