CommDebugMode 入门#
创建日期:2024 年 8 月 19 日 | 最后更新:2024 年 10 月 8 日 | 最后验证:2024 年 11 月 5 日
作者:Anshul Sinha
在本教程中,我们将探讨如何通过追踪分布式训练环境中的集合通信操作,使用 CommDebugMode 和 PyTorch 的分布式张量(DistributedTensor,简称 DTensor)进行调试。
先决条件#
Python 3.8 - 3.11
PyTorch 2.2 或更高版本
什么是 CommDebugMode 以及它为何有用#
随着模型规模的不断增长,用户正寻求利用各种并行策略组合来扩展分布式训练。然而,现有解决方案之间缺乏互操作性构成了一个重大挑战,这主要是因为缺乏一种能够桥接这些不同并行策略的统一抽象。为了解决这个问题,PyTorch 推出了 DistributedTensor(DTensor),它抽象了分布式训练中张量通信的复杂性,提供了无缝的用户体验。然而,在处理现有并行解决方案以及使用 DTensor 等统一抽象开发并行方案时,由于底层集合通信发生的时间和内容缺乏透明度,高级用户在识别和解决问题时可能会感到困难。为了应对这一挑战,Python 上下文管理器 CommDebugMode 将作为 DTensor 的主要调试工具之一,使用户能够查看使用 DTensor 时集合通信操作发生的时间和原因,从而有效解决这一问题。
使用 CommDebugMode#
以下是如何使用 CommDebugMode 的方法
# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
noise_level=1, file_name="transformer_operation_log.txt"
)
# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)
这是 MLPModule 在噪声级别 0 时的输出示例
Expected Output:
Global
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule.net1
MLPModule.relu
MLPModule.net2
FORWARD PASS
*c10d_functional.all_reduce: 1
要使用 CommDebugMode,必须将运行模型的代码包裹在 CommDebugMode 中,并调用用于显示数据的 API。你还可以使用 noise_level 参数来控制所显示信息的详细程度。各噪声级别显示的内容如下:
在上面的例子中,你可以看到集合通信操作 all_reduce 在 MLPModule 的前向传播中发生了一次。此外,你可以使用 CommDebugMode 精确指出 all_reduce 操作发生在 MLPModule 的第二个线性层中。
下方是交互式模块树可视化工具,你可以使用它上传自己的 JSON 转储文件
结论#
在本指南中,我们学习了如何使用 CommDebugMode 来调试分布式张量以及使用 PyTorch 通信集合的并行解决方案。你可以在嵌入式可视化浏览器中使用你自己的 JSON 输出文件。
有关 CommDebugMode 的更多详细信息,请参阅 comm_mode_features_example.py