从零开始设计和实现一个 Python 下的 DAG(有向无环图)

我们一起来从零开始设计和实现一个 Python 下的 DAG(有向无环图),并结合 GitHub 上常见的代码模式进行优化。

第一步:理解 DAG 的基本概念和需求

首先,我们需要明确 DAG 的核心概念:

  • 节点(Node): 代表任务或者操作。
  • 有向边(Directed Edge): 表示节点之间的依赖关系,从一个节点指向另一个节点,意味着前者必须在后者之前完成。
  • 无环(Acyclic): 图中不存在从某个节点出发,经过一系列边最终回到该节点自身的路径。这是 DAG 的关键特性。
  • 我们的目标是实现一个 Python 类,能够:

    1. 添加节点: 允许用户向 DAG 中添加任务节点。
    2. 添加边: 允许用户定义节点之间的依赖关系。
    3. 执行 DAG: 按照依赖关系执行节点代表的任务。
    4. 检测环: 在添加边时或执行前检测是否存在环。

    第二步:初步设计 – 核心数据结构

    在 Python 中,表示图最常用的方式是使用邻接表。对于 DAG,我们可以使用字典来实现邻接表,其中:

  • 键(Key): 代表一个节点。
  • 值(Value): 是一个列表,包含该节点的所有后继节点(依赖于该节点的节点)。
  • 同时,为了方便反向查找依赖关系,我们也可以维护一个反向邻接表:

  • 键(Key): 代表一个节点。
  • 值(Value): 是一个列表,包含所有指向该节点的节点(该节点依赖的节点)。
  • 此外,为了存储节点代表的任务(函数或其他可执行对象),我们可以使用另一个字典:

  • 键(Key): 代表一个节点。
  • 值(Value): 代表与该节点关联的任务。
  • 初步代码框架:

    
    

    python

    代码解读

    复制代码

    class DAG: def __init__(self): self._graph = {} # 邻接表:节点 -> [后继节点] self._reverse_graph = {} # 反向邻接表:节点 -> [前驱节点] self._tasks = {} # 节点 -> 任务 def add_node(self, node, task=None): if node not in self._graph: self._graph[node] = [] self._reverse_graph[node] = [] self._tasks[node] = task def add_edge(self, from_node, to_node): if from_node not in self._graph or to_node not in self._graph: raise ValueError("节点不存在") if to_node not in self._graph[from_node]: self._graph[from_node].append(to_node) self._reverse_graph[to_node].append(from_node) # 思考:这里是否需要进行环检测?

    第三步:完善添加边和环检测

    在 add_edge 方法中,我们需要考虑环检测。一个常见的环检测方法是使用深度优先搜索(DFS)。

    环检测思路:

    1. 维护三个集合:
    2. visited: 已经访问过的节点。
    3. visiting: 当前正在访问的节点。
    4. 从每个节点开始进行 DFS。
    5. 在 DFS 过程中,如果遇到一个节点已经在 visiting 集合中,则说明存在环。
    6. 当一个节点的所有后继节点都访问完毕后,将其从 visiting 集合中移除,并加入 visited 集合。

    改进后的 add_edge 方法:

    
    

    python

    代码解读

    复制代码

    class DAG: # ... (之前的代码) def add_edge(self, from_node, to_node): if from_node not in self._graph or to_node not in self._graph: raise ValueError("节点不存在") if to_node not in self._graph[from_node]: self._graph[from_node].append(to_node) self._reverse_graph[to_node].append(from_node) if self._has_cycle(): # 如果添加边导致环,则撤销操作 self._graph[from_node].remove(to_node) self._reverse_graph[to_node].remove(from_node) raise ValueError("添加边会导致环") def _has_cycle(self): visited = set() visiting = set() def _dfs(node): visiting.add(node) for neighbor in self._graph.get(node, []): if neighbor in visiting: return True if neighbor not in visited: if _dfs(neighbor): return True visiting.remove(node) visited.add(node) return False for node in self._graph: if node not in visited: if _dfs(node): return True return False

    思考与优化 1:

  • 环检测的时机: 我们选择在每次添加边之后进行环检测,这可以尽早发现问题。另一种策略是在执行 DAG 前进行一次性检测。选择哪种方式取决于对性能和错误反馈的需求。频繁检测会增加开销,但能提供更即时的错误信息。
  • 环检测算法: DFS 是一种常见的环检测方法,但对于大型图,可能需要考虑更高效的算法。
  • 第四步:实现 DAG 的执行

    执行 DAG 的核心是按照依赖关系排序节点,这可以通过拓扑排序算法实现。

    拓扑排序思路:

    1. 计算每个节点的入度(指向该节点的边的数量)。
    2. 将所有入度为 0 的节点放入一个队列。
    3. 当队列不为空时:
    4. 从队列中取出一个节点。
    5. 执行该节点对应的任务。
    6. 将该节点的所有后继节点的入度减 1。
    7. 如果某个后继节点的入度变为 0,则将其加入队列。
    8. 如果所有节点都被处理,则执行成功。否则,图中存在环(这应该在添加边时就被检测出来)。

    实现 execute 方法:

    
    

    python

    代码解读

    复制代码

    from collections import deque class DAG: # ... (之前的代码) def execute(self): in_degree = {node: len(self._reverse_graph[node]) for node in self._graph} queue = deque([node for node in self._graph if in_degree[node] == 0]) executed_nodes = [] while queue: node = queue.popleft() print(f"执行节点: {node}") task = self._tasks.get(node) if task: task() # 执行任务 executed_nodes.append(node) for neighbor in self._graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) if len(executed_nodes) != len(self._graph): raise RuntimeError("图中存在环,无法完成拓扑排序") # 理论上不会发生,因为添加边时已检测

    思考与优化 2:

  • 任务执行: 目前的任务执行是简单的函数调用。在实际应用中,任务可能需要传递参数、处理返回值、进行错误处理等。
  • 并行执行: 对于相互独立的节点,可以并行执行以提高效率。可以使用 threading 或 asyncio 模块来实现。
  • 执行顺序: 拓扑排序保证了依赖关系的正确性,但对于没有依赖关系的节点,执行顺序可能不确定。如果需要特定的执行顺序,可以进行额外的排序或优先级控制。
  • 第五步:添加更灵活的任务定义和执行

    目前,我们假设任务是简单的无参函数。为了更灵活地处理各种任务,我们可以允许用户在添加节点时传递任意可调用对象,并允许在执行时传递参数。

    改进后的 add_node 和 execute 方法:

    
    

    python

    代码解读

    复制代码

    class DAG: # ... (之前的代码) def add_node(self, node, task=None, *args, **kwargs): if node not in self._graph: self._graph[node] = [] self._reverse_graph[node] = [] self._tasks[node] = (task, args, kwargs) # 存储任务和参数 def execute(self): in_degree = {node: len(self._reverse_graph[node]) for node in self._graph} queue = deque([node for node in self._graph if in_degree[node] == 0]) executed_nodes = {} # 存储执行结果 while queue: node = queue.popleft() task_info = self._tasks.get(node) if task_info: task, args, kwargs = task_info print(f"执行节点: {node}, 任务: {task.__name__ if callable(task) else task}") try: result = task(*args, **kwargs) # 执行任务并获取结果 executed_nodes[node] = result except Exception as e: print(f"节点 {node} 执行失败: {e}") raise # 可以选择抛出异常或继续执行 for neighbor in self._graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) if len(executed_nodes) != len(self._graph): raise RuntimeError("图中存在环或部分节点未执行") return executed_nodes # 返回执行结果

    思考与优化 3:

  • 任务参数传递: 允许在添加节点时传递参数,使得任务可以接收特定的输入。
  • 任务执行结果: 存储每个节点的执行结果,方便后续节点使用。
  • 错误处理: 在任务执行过程中添加了 try-except 块,可以捕获异常并进行处理。可以根据需求选择抛出异常、记录日志或跳过该节点。
  • 依赖注入: 如果任务之间需要传递数据,可以通过执行结果来实现简单的依赖注入。例如,一个节点的输出可以作为另一个节点的输入。
  • 第六步:借鉴 GitHub 常见代码模式进行优化

    在 GitHub 上,常见的代码模式可以帮助我们提高代码的可读性、可维护性和性能。

  • 使用装饰器: 可以使用装饰器来简化任务的添加和定义。
  • 上下文管理器: 可以使用上下文管理器来管理资源的分配和释放。
  • 生成器: 可以使用生成器来处理大型数据集或异步操作。
  • 类型提示: 使用类型提示可以提高代码的可读性和可维护性,并帮助静态类型检查工具发现错误。
  • 单元测试: 编写单元测试来验证 DAG 的各个功能是否正常工作。
  • 示例:使用装饰器简化任务添加

    
    

    python

    代码解读

    复制代码

    class DAG: # ... (之前的代码) def task(self, node, *args, **kwargs): def decorator(func): self.add_node(node, func, *args, **kwargs) return func return decorator # 使用装饰器添加任务 dag = DAG() @dag.task("task_a") def task_a(): print("执行任务 A") return "result_a" @dag.task("task_b", input_value=10) def task_b(input_value): print(f"执行任务 B,输入: {input_value}") return input_value * 2 dag.add_edge("task_a", "task_b") dag.execute()

    思考与优化 4:

  • 模块化设计: 将 DAG 的不同功能模块化,例如节点管理、边管理、执行引擎、环检测等,提高代码的可维护性。
  • 配置化: 允许通过配置文件或外部数据源来定义 DAG 的结构和任务,提高灵活性。
  • 可视化: 提供可视化 DAG 结构的功能,方便用户理解和调试。可以使用 graphviz 等库。
  • 异步执行: 使用 asyncio 或 concurrent.futures 实现异步并行执行,提高性能。
  • 第七步:进一步的思考和扩展

  • 错误处理策略: 更精细的错误处理,例如重试机制、回滚操作、依赖失败处理等。
  • 数据传递和共享: 更复杂的数据传递机制,例如使用消息队列或共享内存。
  • 状态管理: 跟踪 DAG 的执行状态,例如节点的状态(等待、运行中、已完成、失败)。
  • 监控和日志: 集成监控和日志功能,方便观察 DAG 的运行情况。
  • 与其他工具集成: 例如与工作流引擎(如 Airflow、Luigi)集成。
  • 总结与最终代码(包含一些优化):

    
    

    python

    代码解读

    复制代码

    from collections import deque from functools import wraps class DAG: def __init__(self): self._graph = {} self._reverse_graph = {} self._tasks = {} def add_node(self, node, task=None, *args, **kwargs): if node not in self._graph: self._graph[node] = [] self._reverse_graph[node] = [] self._tasks[node] = (task, args, kwargs) def add_edge(self, from_node, to_node): if from_node not in self._graph or to_node not in self._graph: raise ValueError("节点不存在") if to_node not in self._graph[from_node]: self._graph[from_node].append(to_node) self._reverse_graph[to_node].append(from_node) if self._has_cycle(): self._graph[from_node].remove(to_node) self._reverse_graph[to_node].remove(from_node) raise ValueError("添加边会导致环") def _has_cycle(self): visited = set() visiting = set() def _dfs(node): visiting.add(node) for neighbor in self._graph.get(node, []): if neighbor in visiting: return True if neighbor not in visited: if _dfs(neighbor): return True visiting.remove(node) visited.add(node) return False for node in self._graph: if node not in visited: if _dfs(node): return True return False def task(self, node, *args, **kwargs): def decorator(func): self.add_node(node, func, *args, **kwargs) @wraps(func) def wrapper(*_args, **_kwargs): return func(*_args, **_kwargs) return wrapper return decorator def execute(self): in_degree = {node: len(self._reverse_graph[node]) for node in self._graph} queue = deque([node for node in self._graph if in_degree[node] == 0]) executed_nodes = {} while queue: node = queue.popleft() task_info = self._tasks.get(node) if task_info: task, args, kwargs = task_info print(f"执行节点: {node}, 任务: {task.__name__ if callable(task) else task}") try: result = task(*args, **kwargs) executed_nodes[node] = result except Exception as e: print(f"节点 {node} 执行失败: {e}") raise for neighbor in self._graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) if len(executed_nodes) != len(self._graph): raise RuntimeError("图中存在环或部分节点未执行") return executed_nodes # 示例用法 dag = DAG() @dag.task("start") def start_task(): print("开始任务") return 10 @dag.task("process", multiplier=2) def process_task(value, multiplier): print(f"处理任务,值: {value},乘数: {multiplier}") return value * multiplier @dag.task("end") def end_task(value): print(f"结束任务,最终值: {value}") dag.add_edge("start", "process") dag.add_edge("process", "end") results = dag.execute() print("执行结果:", results)

    通过这个过程,我们从零开始设计并实现了一个 Python 下的 DAG,并逐步进行了优化和扩展

    作者:Java八股文面试

    物联沃分享整理
    物联沃-IOTWORD物联网 » 从零开始设计和实现一个 Python 下的 DAG(有向无环图)

    发表回复