评价此页

用于调试卡住作业的飞行记录器#

作者: Chirag Pandya, Junjie Wang

你将学到什么#

  • 了解一种用于调试分布式训练期间卡住作业的新工具。

  • 了解如何启用该工具以及如何使用收集到的数据来分析卡住的作业。

先决条件#

  • PyTorch 版本 2.5 或更高版本。

  • tabulate。你可以通过运行 pip install tabulate 来安装。

概述#

AI 分布式训练作业是指使用网络连接的多个设备(如 GPU 或 CPU)训练机器学习模型的过程。这种方法可以更快、更有效地训练需要大量计算资源的大型模型。工程师的目标是尽快完成 AI 训练作业并进行持续改进,以便后续训练可以更快地完成。训练有素、可用的模型是最终期望的成果。完成训练的最大障碍之一就是“*卡住的作业*”的概念。

当分布式 AI 训练作业在很长一段时间内停止取得有意义的进展时,就被认为是卡住的

作业可能因各种原因而卡住

  • 数据饥饿: 当训练作业接收数据的速率未达到预期时发生,可能是由于数据管道或数据源的问题。

  • 资源限制: 如果运行作业的系统没有足够的计算资源(如 CPU、GPU 或内存),作业可能无法继续。

  • 网络问题: 在分布式训练设置中,模型的不同部分或数据可能在不同的设备上处理。如果存在网络问题,这些设备之间的通信可能会中断,导致作业卡住。

  • 软件错误或缺陷: 训练代码或底层库和框架中的错误也可能导致作业卡住。

  • 同步问题: 在分布式训练中,计算的不同部分通常是并行运行的,并且需要在某些点进行同步。如果同步失败,作业就会卡住。例如,如果一个或多个 rank 未能加入一个集体操作,而其余 rank 已加入,则可能发生死锁。这将导致作业无限期等待才能继续。

正如其名,飞行记录器在集体操作运行时捕获诊断信息。捕获的诊断信息用于帮助识别作业卡住时问题的根本原因。飞行记录器包含两个核心部分:

  • 收集部分:启用后,有关集体操作的信息将被记录在内存中的循环缓冲区中。作业超时或按需时,可以检索内存缓冲区或将其转储到文件。

  • 分析器脚本位于 tools/flight_recorder 目录中(下文详述)。

    分析器脚本使用收集到的数据运行已知的启发式方法,并尝试自动识别导致作业停滞的潜在问题。

启用飞行记录器#

要使飞行记录器的初始版本正常工作,需要设置三个环境变量。

  • TORCH_NCCL_TRACE_BUFFER_SIZE = (0, N):将 N 设置为正数即可启用收集。 N 表示将在内部循环缓冲区中保留的条目数。我们建议将此值设置为 *2000*。默认值为 2000

  • TORCH_NCCL_DUMP_ON_TIMEOUT = (true, false):将其设置为 true 将在作业超时时将诊断文件写入磁盘。如果启用,作业运行目录中将为每个 rank 输出一个文件。默认值为 false

  • TORCH_FR_DUMP_TEMP_FILE:设置飞行记录器将转储到的文件路径(带文件前缀)。每个 rank 一个文件。默认值为 /tmp/nccl_trace_rank_

可选设置

  • TORCH_NCCL_TRACE_CPP_STACK = (true, false):将其设置为 true 可在飞行记录器中捕获 C++ 堆栈跟踪。C++ 堆栈跟踪有助于提供从 PyTorch Python 调用到 C++ 原始实现的精确代码路径。另请参阅其他设置中的 TORCH_SYMBOLIZE_MODE

  • TORCH_NCCL_ENABLE_TIMING = (true, false):将其设置为 true 将在每个集体操作的开始时启用额外的 CUDA 事件,并记录每个集体操作的*持续时间*。这可能会产生一些 CPU 开销。在收集到的数据中,*duration* 字段表示每个集体操作执行所花费的时间。

其他设置#

  • TORCH_SYMBOLIZE_MODE = (dladdr, addr2line, fast):此设置确定用于从正在运行的程序检索 C++ 跟踪的程序。

    默认设置为 addr2line

    fast 是一种新的实验模式,显示比传统的 addr2line 快得多。请将此设置与 TORCH_NCCL_TRACE_CPP_STACK 结合使用,以在飞行记录器数据中收集 C++ 跟踪。

  • 如果您不想将飞行记录器数据转储到本地磁盘,而是转储到自己的存储中,则可以定义自己的写入器类。该类应继承自 ::c10d::DebugInfoWriter(代码),然后使用 ::c10d::DebugInfoWriter::registerWriter (代码) 注册新写入器,然后再初始化 PyTorch 分布式。

通过 API 检索飞行记录器数据#

您也可以通过 API 调用检索飞行记录器数据。下面显示了带默认参数的 API

torch._C._distributed_c10d._dump_nccl_trace(includeCollectives=True, includeStackTraces=True, onlyActive=False)

要查看数据,您可以像下面这样 unpickle

t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
print(t)

飞行记录器文件格式#

飞行记录器文件以 pickle 格式转储。文件被写入本地磁盘或挂载的网络文件系统(NFS)文件夹。

下面显示了飞行记录器 unpickled 文件的内容

{
  "version": "2.5",
  "pg_config": {
    "0": {
    "name": "0",
    "desc": "default_pg",
    "ranks": "[0, 1]"
    }
  },
  "pg_status": {
    "0": {
    "last_enqueued_collective": 2,
    "last_started_collective": -1,
    "last_completed_collective": 2
    }
  },
  "entries": [
  {
    "frames": [
    {
    "name": "test_short_pickle",
    "filename": "pytorch/test/distributed/test_c10d_nccl.py",
    "line": 3647
    },
    {
    "name": "spawn_main",
    "filename": ".conda/envs/pytorch-3.10/lib/python3.10/multiprocessing/spawn.py",
    "line": 116
    },
    {
    "name": "<module>",
    "filename": "<string>",
    "line": 1
    }
    ],
    "record_id": 0,
    "pg_id": 0,
    "process_group": ("0", "default_pg"),
    "collective_seq_id": 1,
    "p2p_seq_id": 0,
    "op_id": 1,
    "profiling_name": "nccl:all_reduce",
    "time_created_ns": 1724779239936775119,
    "input_sizes": [[3, 4]],
    "input_dtypes": ["Float"],
    "output_sizes": [[3, 4]],
    "output_dtypes": ["Float"],
    "state": "completed",
    "time_discovered_started_ns": null,
    "time_discovered_completed_ns": 1724779239975811724,
    "retired": true,
    "timeout_ms": 600000,
    "is_p2p": false
    },
    ...
    ]
}

分析飞行记录器转储#

我们在 pytorch/tools/flight_recorder 目录中提供了方便的脚本来分析捕获的数据。

要运行方便的脚本,请按照以下步骤操作

  1. 将一个 rank 的所有文件复制到一个目录中。

  2. 要运行脚本,请使用此命令

python fr_trace.py <dump dir containing trace files> [-o <output file>]

如果您安装了 PyTorch 的 nightly 版本或使用 USE_DISTRIBUTED=1 从头开始构建,则可以直接使用以下命令

torchfrtrace <dump dir containing trace files> [-o <output file>]

目前,我们支持分析器脚本的两种模式。第一种模式允许脚本对解析后的飞行记录器转储应用一些启发式方法,以生成报告,识别可能导致超时的原因。第二种模式是简单地输出原始转储。默认情况下,脚本会打印所有 rank 和所有 ``ProcessGroups``(PGs)的飞行记录器转储。可以使用 `–selected-ranks` 参数指定 rank,使用 `–pg-filters` 参数指定 PGs,从而将范围缩小到特定的 rank 和 PGs。命令示例是

注意:需要 tabulate 模块,因此您可能需要先 `pip install` 它。

python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]

端到端示例#

为了演示飞行记录器的使用,我们将使用一个程序,在该程序中人为制造不匹配的集体操作。在此示例中,rank0 被编程为执行额外的集体操作。飞行记录器转储文件保存在 /tmp 目录中。为方便演示,我们将此程序命名为 crash.py

注意

请注意,这是一个简化的示例。在实际场景中,过程会涉及更多复杂性。

import torch
import torch.distributed as dist
import os
from datetime import timedelta

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert world_size <= 8, "world size must be less than or equal to 8"
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
device = torch.device(f"cuda:{local_rank}")
print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

# Initialize the process group with a small timeout so that jobs fail quickly
dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

a = torch.full((3, 4), float(local_rank), device=device)
# Write some collectives to populate Flight Recorder data
for i in range(2):
  print(f"calling allreduce on {local_rank=}")
  f = dist.all_reduce(a)

# rank0 is doing an additional collective
if local_rank == 0:
  print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
  b = torch.full((4,5), float(local_rank), device=device)
  f = dist.all_reduce(b)

for i in range(2):
  print(f"calling allreduce on {local_rank=}")
  f = dist.all_reduce(a)

torch.cuda.synchronize(device=device)
print(f"{local_rank=} exiting")

要运行此程序,请使用 torchrun

torchrun --nnodes=1 --nproc_per_node=2 crash.py

您应该会在 /tmp 目录中看到两个文件

$ls /tmp/trace*
# Expected output
/tmp/trace_0 /tmp/trace_1

最后,要分析这两个文件,我们使用 torchfrtrace 命令

torchfrtrace --prefix "trace_" /tmp/

跟踪命令的输出旨在方便人类阅读。它包括有关导致失败的集体操作集的信息。上述命令的输出如下所示。我们可以清楚地看到 rank 1 未加入“all_reduce”集体操作。

结论#

在本教程中,我们了解了 PyTorch 的一种新诊断工具——飞行记录器。我们讨论了如何启用飞行记录器来从机器收集诊断数据。此外,我们还探讨了如何使用 PyTorch 存储库中 tools/flight_recorder 目录中的便利脚本来分析从飞行记录器捕获的数据。