用于调试卡死任务的飞行记录器#
作者: Chirag Pandya, Junjie Wang
您将学到什么#
了解一种用于调试分布式训练中卡死任务的新工具。
学习如何启用该工具,并使用收集到的数据来分析卡死任务。
先决条件#
PyTorch 2.5 或更高版本。
tabulate。你可以通过运行
pip install tabulate进行安装。
概述#
AI 分布式训练任务是指在网络连接的多个设备(如 GPU 或 CPU)上训练机器学习模型的过程。这种方法可以更快、更高效地训练需要大量计算资源的大型模型。工程师的目标是尽可能快地完成 AI 训练任务,并进行持续改进,以便后续训练能更快完成。一个经过训练、可用的模型是最终期望的结果。完成训练的最大障碍之一就是“卡死任务”的概念。
当分布式 AI 训练任务在长时间内无法取得有意义的进展时,即被视为卡死。
任务卡死的原因多种多样:
数据饥饿 (Data Starvation): 当训练任务无法以预期的速率接收数据时,就会发生这种情况,这可能是由于数据流水线或数据源的问题导致的。
资源限制 (Resource Constraints): 如果运行任务的系统没有足够的计算资源(如 CPU、GPU 或内存),任务可能无法继续。
网络问题 (Network Issues): 在分布式训练设置中,模型的不同部分或数据可能在不同的设备上处理。如果存在网络问题,这些设备之间的通信可能会中断,导致任务卡死。
软件错误或缺陷 (Software Bugs or Errors): 训练代码或底层库和框架中的错误也可能导致任务卡死。
同步问题 (Synchronization Issues): 在分布式训练中,计算的不同部分通常并行运行,需要在某些点进行同步。如果这种同步失败,任务可能会卡死。例如,如果一个或多个进程组 (rank) 未能加入集合通信操作,而其余进程组已经加入,就会发生死锁。这导致任务无限期地等待进展。
顾名思义,飞行记录器 (Flight Recorder) 在集合通信操作运行时捕获诊断信息。所捕获的诊断信息有助于在任务卡死时识别问题的根本原因。飞行记录器由两个核心部分组成:
收集部分:启用后,有关集合通信的信息会被记录在内存中的循环缓冲区中。任务超时或按需时,可以检索或将内存缓冲区转储到文件。
- 分析器脚本:在 torch/distributed/flight_recorder 目录中提供了一个分析器脚本(详情见下文)。
分析器脚本利用收集到的数据运行已知的启发式算法,尝试自动识别导致任务停滞的根本问题。
启用飞行记录器#
要使初始版本的飞行记录器工作,需要设置三个环境变量。
TORCH_NCCL_TRACE_BUFFER_SIZE = (0, N):将N设置为一个正数即可启用收集功能。N表示循环缓冲区内部保留的条目数。建议将此值设置为 2000。默认值为2000。TORCH_NCCL_DUMP_ON_TIMEOUT = (true, false):设置为true时,任务超时会将诊断文件写入磁盘。如果启用,任务的运行目录中将为每个进程组输出一个文件。默认值为false。TORCH_FR_DUMP_TEMP_FILE:设置飞行记录器转储文件的路径及文件前缀。每个进程组一个文件。默认值为/tmp/nccl_trace_rank_。
可选设置
TORCH_NCCL_TRACE_CPP_STACK = (true, false):设置为 true 可在飞行记录器中捕获 C++ 堆栈跟踪。C++ 堆栈跟踪有助于提供从 PyTorch Python 调用到原始 C++ 实现的精确代码路径。另请参阅附加设置中的TORCH_SYMBOLIZE_MODE。TORCH_NCCL_ENABLE_TIMING = (true, false):设置为true将在每个集合通信开始时启用额外的 CUDA 事件,并记录每个集合通信的持续时间。这可能会产生一些 CPU 开销。在收集的数据中,duration 字段表示每个集合通信执行所需的时间。
附加设置#
TORCH_SYMBOLIZE_MODE = (dladdr, addr2line, fast):此设置确定用于从运行中的程序检索 C++ 跟踪的程序。默认设置为
addr2line。fast是一种新的实验模式,已被证明比传统的addr2line快得多。将此设置与TORCH_NCCL_TRACE_CPP_STACK结合使用,可在飞行记录器数据中收集 C++ 跟踪信息。
如果你不想将飞行记录器数据转储到本地磁盘,而是希望保存到自己的存储中,可以定义自己的编写器 (writer) 类。该类应继承自
::c10d::DebugInfoWriter(代码),然后在启动 PyTorch 分布式之前,使用::c10d::DebugInfoWriter::registerWriter(代码) 注册新编写器。
通过 API 检索飞行记录器数据#
你也可以通过 API 调用检索飞行记录器数据。带有默认参数的 API 如下所示:
torch._C._distributed_c10d._dump_nccl_trace(includeCollectives=True, includeStackTraces=True, onlyActive=False)
要查看数据,你可以使用 unpickle,如下所示:
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
print(t)
飞行记录器文件格式#
飞行记录器文件以 pickle 格式转储。文件被写入本地磁盘或挂载的共享 NFS 文件夹。
飞行记录器 unpickled 文件内容如下所示:
{
"version": "2.5",
"pg_config": {
"0": {
"name": "0",
"desc": "default_pg",
"ranks": "[0, 1]"
}
},
"pg_status": {
"0": {
"last_enqueued_collective": 2,
"last_started_collective": -1,
"last_completed_collective": 2
}
},
"entries": [
{
"frames": [
{
"name": "test_short_pickle",
"filename": "pytorch/test/distributed/test_c10d_nccl.py",
"line": 3647
},
{
"name": "spawn_main",
"filename": ".conda/envs/pytorch-3.10/lib/python3.10/multiprocessing/spawn.py",
"line": 116
},
{
"name": "<module>",
"filename": "<string>",
"line": 1
}
],
"record_id": 0,
"pg_id": 0,
"process_group": ("0", "default_pg"),
"collective_seq_id": 1,
"p2p_seq_id": 0,
"op_id": 1,
"profiling_name": "nccl:all_reduce",
"time_created_ns": 1724779239936775119,
"input_sizes": [[3, 4]],
"input_dtypes": ["Float"],
"output_sizes": [[3, 4]],
"output_dtypes": ["Float"],
"state": "completed",
"time_discovered_started_ns": null,
"time_discovered_completed_ns": 1724779239975811724,
"retired": true,
"timeout_ms": 600000,
"is_p2p": false
},
...
]
}
分析飞行记录器转储文件#
我们在 pytorch/torch/distributed/flight_recorder 目录中提供了方便的脚本,用于分析捕获的数据。
要运行该方便脚本,请遵循以下步骤:
将来自一个进程组的所有文件复制到同一个目录中。
要运行脚本,请使用此命令:
python fr_trace.py <dump dir containing trace files> [-o <output file>]
如果你安装了 PyTorch 夜间构建版,或者使用 USE_DISTRIBUTED=1 从源码构建,你可以直接使用以下命令:
torchfrtrace <dump dir containing trace files> [-o <output file>]
目前,我们支持分析器脚本的两种模式。第一种模式允许脚本对解析后的飞行记录器转储文件应用一些启发式算法,生成一份报告,指出导致超时的潜在罪魁祸首。第二种模式仅仅是输出原始转储文件。默认情况下,脚本会打印所有进程组和所有 ``ProcessGroups`` (PGs) 的飞行记录器转储信息。可以使用 –selected-ranks 参数指定进程组,使用 –pg-filters 参数指定 PGs,从而缩小范围。示例命令如下:
警告:需要 tabulate 模块,因此你可能需要先运行 pip install 进行安装。
python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]
端到端示例#
为了演示飞行记录器的使用,我们将使用一个小程序来模拟不匹配的集合通信。在此示例中,rank0 被编程执行额外的集合通信。飞行记录器转储文件被保存到 /tmp 目录中。为了演示,我们将此程序命名为 crash.py。
注意
请注意,这是一个简化示例。在现实场景中,流程会涉及更多复杂性。
import torch
import torch.distributed as dist
import os
from datetime import timedelta
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert world_size <= 8, "world size must be less than or equal to 8"
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
device = torch.device(f"cuda:{local_rank}")
print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")
# Initialize the process group with a small timeout so that jobs fail quickly
dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))
a = torch.full((3, 4), float(local_rank), device=device)
# Write some collectives to populate Flight Recorder data
for i in range(2):
print(f"calling allreduce on {local_rank=}")
f = dist.all_reduce(a)
# rank0 is doing an additional collective
if local_rank == 0:
print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
b = torch.full((4,5), float(local_rank), device=device)
f = dist.all_reduce(b)
for i in range(2):
print(f"calling allreduce on {local_rank=}")
f = dist.all_reduce(a)
torch.cuda.synchronize(device=device)
print(f"{local_rank=} exiting")
要运行此程序,请使用 torchrun:
torchrun --nnodes=1 --nproc_per_node=2 crash.py
你应该会在 /tmp 目录中看到两个文件:
$ls /tmp/trace*
# Expected output
/tmp/trace_0 /tmp/trace_1
最后,要分析这两个文件,我们使用 torchfrtrace 命令:
torchfrtrace --prefix "trace_" /tmp/
trace 命令的输出旨在便于人工阅读。它包含有关导致故障的集合通信集的信息。上述命令的输出如下所示。我们可以清楚地看到,rank 1 没有加入 “all_reduce” 集合通信。
结论#
在本教程中,我们了解了一个名为飞行记录器 (Flight Recorder) 的新 PyTorch 诊断工具。我们讨论了如何启用飞行记录器以从机器收集诊断数据。此外,我们还探索了如何使用位于 PyTorch 仓库 torch/distributed/flight_recorder 目录中的便利脚本来分析从飞行记录器捕获的数据。