评价此页

FSDP 笔记#

创建于:2024 年 1 月 21 日 | 最后更新于:2025 年 6 月 7 日

FSDP 预取细微差别#

为了使 forward 全局收集与 forward 计算重叠,有两种可能的机制:

  1. 隐式前向预取(始终启用)

  2. 显式前向预取(forward_prefetch=True

隐式 forward 预取是指依赖于从单独的 CUDA 流中发出全局收集操作,以允许全局收集与之前(从 CPU 角度看)发出的 forward 计算重叠。例如,如果我们有第 0 层全局收集 -> 第 0 层 forward 计算 -> 第 1 层全局收集 -> ……,那么即使 CPU 线程在其之后发出,第 1 层全局收集也可以与第 0 层 forward 计算重叠。(第一次全局收集将无法与任何操作重叠。)

显式 forward 预取是指更改 CPU 线程的发布顺序:例如,第 0 层全局收集 -> 第 1 层全局收集 -> 第 0 层 forward 计算 -> ……。在 eager 模式下,当仍在执行第 0 层时,通常无法知道下一层是哪一层(例如示例中的第 1 层)。因此,显式 forward 预取只应用于执行顺序在迭代之间固定的模型(我们有时称之为“静态图”)。不满足此约束的模型示例是 FLAVA)。

显式 forward 预取仅节省了发出层 forward 计算内核所需的时间,代价是下一个全局收集的输出张量必须在当前张量仍在使用时分配。通过在当前 forward 计算内核之前发出下一个全局收集,下一个全局收集可以在 GPU 上更早开始。对于大多数 LLM 工作负载,情况并非如此,因此没有理由启用 forward_prefetch=True

相比之下,对于 backward,我们必须使用显式 backward 预取,否则通信和计算将完全没有重叠。原因是我们将一个 NCCL 进程组用于全局收集和 reduce-scatter(部分原因是,在早期的 NCCL 版本中,在同一设备上的相同 rank 上同时使用多个 NCCL 进程组是不安全的)。单个 NCCL 进程组意味着单个内部 NCCL 流,在该流上 reduce-scatter 和全局收集串行运行。因此,除非我们明确重新排序 CPU 发布顺序为下一个全局收集 -> 当前 reduce-scatter,否则当前 reduce-scatter 将阻塞下一个全局收集,从而阻塞下一个 backward 计算,阻止当前 reduce-scatter 重叠。

通信负载大小#

在 FSDP 中,通信是:

  1. forward 中的参数全局收集

  2. backward 中的参数全局收集

  3. backward 中的梯度 reduce-scatter

如果使用激活检查点(checkpoint()),则没有额外的通信,因为参数在 backward 期间无论如何都会被预取。

在 FSDP 设计中,每个 rank 的通信负载确定如下:每次调用 FullyShardedDataParallel 都会创建一个通信组,该组由 module.parameters() 中的参数组成,但任何已分配给嵌套 FullyShardedDataParallel 实例的参数除外。例如,对于 Llama,如果您将 FullyShardedDataParallel 应用于每个 Transformer 块和根模块,那么每个 Transformer 块都有一个通信组,最后是包含初始嵌入和最终线性层的通信组。每个通信组对应于一次全局收集调用和一次 reduce-scatter 调用。通过这种方式,您如何应用 FullyShardedDataParallel 决定了通信大小。通常,将 FSDP 应用于每个 Transformer 块是 LLM 的一个很好的启发式方法,考虑到当前设计,很难做得更好。

让我们考虑一个例子,我们有一个基于 Transformer 的模型在 8 个 GPU 上分片,分片仅发生在 Transformer 块级别,每个 Transformer 块包含 1.6B 参数,参数为 fp32(每个 4 字节)。这意味着一旦分片,每个 Transformer 块将在每个 rank 上包含 0.2B 参数。

  • forward 传递将以 0.2*4 = 0.8GB 的块进行全局收集通信。

  • backward 传递将进行两次 0.8GB 的通信(1x 全局收集和 1x reduce-scatter)。

换句话说,将有 3 次通信,每次负载为 0.8GB。如果模型由 10 个 Transformer 块组成,则总共有 30 次通信,总计 30*0.8=24GB

形式化每个 rank 每次通信的负载大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus (GBs)。

请注意,在此示例中,我们未包含嵌入所需的额外通信,这也应计算在内。并且计算将取决于输入和输出嵌入是否绑定。如果它们未绑定,则通信将增加 2 倍。

FSDP 缓冲区大小#

首先,让我们介绍为通信分配的缓冲区。

forward 当前需要 2 倍的全局收集缓冲区大小。原因如下:

正如 FSDP 预取细微差别 中解释的,在显式 forward 预取(forward_prefetch=True)的情况下,例如第 0 层全局收集 -> 第 0 层前向计算 -> 第 1 层全局收集,需要 2 个全局收集大小的缓冲区,因为一个缓冲区用于当前 forward,而另一个用于预取。

虽然隐式 forward 预取(forward_prefetch=False,默认)在理论上只需要 1 个缓冲区,但实际上仍然是 2 倍的全局收集大小缓冲区。原因是,在扁平参数 FSDP 设计中,我们不从全局收集缓冲区中复制数据。用于计算的参数直接在全局收集缓冲区中视图化(事实上,“扁平参数”的主要好处正是这个原因)。在这种情况下,当“第 1 层全局收集”与“第 0 层前向计算”重叠时,“第 0 层前向计算”正在使用在“第 0 层全局收集”缓冲区中视图化的参数。

那么一个自然的问题是,什么时候会需要 forward_prefetch=False 呢?对于静态图模型(如大多数 LLM),有一个主要的、技术性的原因。更多的是,实际上,我们为一些受 CPU 限制的内部模型快速添加了这个选项,并且没有在单元测试中测试过它的每个代码路径,所以我们对其信心不足。forward_prefetching=False 可能更容易理解,因为我们不必检查记录的前向顺序作为可能的“故障模式”;模块的全局收集始终可以在其分析器跟踪中的自己的 record_function 标签下找到。

backward 目前至少需要 2 倍的全局收集缓冲区大小,并且可能更多。原因如下:

当前的 FSDP 设计使用 recordStream 来管理在一个流中生成并在另一个流中使用的分配,这可能导致比预期更多的内存使用。具体多多少是“非确定性”的,因为它取决于 GPU 内核计时相对于 CPU 的情况。limit_all_gathers=True 参数是对此的缓解措施——更多细节请参阅 FSDP & CUDACachingAllocator 中的讨论。

现有 FSDP 与自动梯度的工作方式

  • 现有 FSDP 全局收集 flat_param,它是自动梯度叶子。

  • 它调用 torch.split 以获取 flat_param 中与其组成原始参数对应的 1D 视图。

  • 它在每个 1D 分割上调用 torch.view 以视图回到 ND。

  • 这意味着在 backward 中,我们最终会得到 ViewBackward(ND -> 1D)和 SplitWithSizesBackward(这是一个连接)。特别是,每个单独的梯度都作为单独的分配计算,并且会发生显式连接以构建 reduce-scatter 输入缓冲区。这意味着在该峰值内存点,reduce-scatter 实际上需要 2 倍的缓冲区大小。

总而言之,对于 backward,它大约是 reduce-scatter 的 2 倍缓冲区大小,加上任何 recordStream 的影响。

其次,我们讨论额外的缓冲区。

一旦从所有 rank 收集了分片参数,它们需要一个额外的缓冲区,大小为 `total_transformer_block_params_in_B*dtype_bytes` 来存储完整参数——因此继续前面的例子,如果每个 Transformer 块是 1.6B 参数且参数为 fp32,那么它将是 `1.6*4=6.4GB` 的缓冲区。

并且需要两个这样的缓冲区,因为一个正在使用,另一个正在预取。

总而言之,我们有:

  1. 2 倍通信缓冲区,大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus

  2. 2 倍未分片 Transformer 块参数缓冲区,大小为 total_transformer_block_params_in_B*dtype_bytes

或者如果您一直遵循示例,则为:

  1. 2 * 1.6 * 4 / 8 = 1.6GB

  2. 2 * 1.6 * 4 = 12.8GB

总计 14.4GB

现在我们简要讨论一下嵌入会发生什么,因为我们之前在计算中忽略了它们。

根据我们讨论的规则,您在注释中从“通信缓冲区大小确定如下”开始,我们可以分析如下:

  • 假设我们将 FSDP 应用于根模块(例如 Transformer 类)。假设我们进一步将 FSDP 应用于每个 Transformer 块(例如 TransformerBlock 类)。

  • 最常见的情况是,嵌入和最终线性投影是根 Transformer 类的直接子项。

  • 根据我们的规则,这意味着嵌入和最终线性投影被分配给根 Transformer 的扁平参数。

  • 我们有另一个特殊规则,即根模块在前向传递后不会释放其参数,因为它们无论如何都会在反向传递中立即进行全局收集。

  • 综上所述,这意味着包含嵌入和最终投影的根模块的扁平参数在开始前向传递时进行全局收集,并保留在 GPU 内存中直到反向传递结束。

  • 如果嵌入和最终线性层没有权重绑定,那么我们_可以_进一步将 FSDP 应用于嵌入和最终线性层。对于权重绑定的参数,我们要求它们是同一扁平参数的一部分(否则会重复计算)。这将允许嵌入在使用后在前向传递中释放,并且仅在反向传递结束时进行全局收集。

  • 希望这能更好地说明——每个 FSDP 模块都会分配其 module.parameters 中的参数,但已分配给另一个嵌套 FSDP 模块的参数除外,并且 FSDP 模块的 forward 定义了其参数的“活跃”间隔。因此,嵌套的 nn.Module 结构会影响全局收集/释放调度,从而影响内存/吞吐量性能。