评价此页

Fully Sharded Data Parallel (FSDP2) 入门#

创建日期:2022年3月17日 | 最后更新:2025年9月2日 | 最后验证:2024年11月5日

作者: Wei Feng, Will Constable, Yifan Mao

注意

edit 请在 pytorch/examples 中查看本教程的代码。FSDP1 已弃用。FSDP1 教程已归档于 [1][2]

FSDP2 的工作原理#

DistributedDataParallel (DDP) 训练中,每个 rank 都拥有一个模型副本并处理一批数据,最后使用 all-reduce 来同步各 rank 间的梯度。

与 DDP 相比,FSDP 通过分片(sharding)模型参数、梯度和优化器状态来减少 GPU 内存占用。这使得训练无法放入单个 GPU 的模型成为可能。如下图所示:

  • 在前向和后向计算之外,参数被完全分片

  • 在前向和后向计算之前,分片后的参数通过 all-gather 聚合成未分片的参数

  • 在后向计算内部,局部的未分片梯度通过 reduce-scatter 聚合成分片后的梯度

  • 优化器使用分片后的梯度更新分片后的参数,产生分片后的优化器状态

FSDP workflow

FSDP 可以被看作是将 DDP 的 all-reduce 操作分解为 reduce-scatter 和 all-gather 操作

FSDP all-gather and reduce-scatter

FSDP1 相比,FSDP2 具有以下优势:

  • 将分片参数表示为在 dim-i 上分片的 DTensor,从而能够轻松操作单个参数、实现无需通信的分片 state dict,以及更简单的 meta-device 初始化流程。

  • 改进了内存管理系统,通过避免 recordStream (文档) 实现了更低且确定性的 GPU 内存使用,且无需任何 CPU 同步。

  • 提供了一个张量子类扩展点以自定义 all-gather,例如用于 float8 线性层的 float8 all-gather (文档),以及用于 QLoRA 的 NF4 (文档)。

  • 可以在同一个通信组中混合使用冻结和非冻结参数,而无需额外内存。

如何使用 FSDP2#

模型初始化#

在子模块上应用 fully_shard:与 DDP 不同,我们应该对子模块以及根模型应用 fully_shard。在下面的 transformer 示例中,我们首先对每一层应用 fully_shard,然后对根模型应用。

  • layers[i] 的前向计算期间,其余层保持分片状态以减少内存占用

  • fully_shard(model) 内部,FSDP2 会将 model.layers 中的参数排除在外,并将剩余参数分类到参数组中,以便进行高效的 all-gather 和 reduce-scatter

  • fully_shard 会将分片后的模型移动到实际的训练设备(例如 cuda

命令: torchrun --nproc_per_node 2 train.py

from torch.distributed.fsdp import fully_shard, FSDPModule
model = Transformer()
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

assert isinstance(model, Transformer)
assert isinstance(model, FSDPModule)
print(model)
#  FSDPTransformer(
#    (tok_embeddings): Embedding(...)
#    ...
#    (layers): 3 x FSDPTransformerBlock(...)
#    (output): Linear(...)
#  )

我们可以通过 print(model) 来检查嵌套包装情况。FSDPTransformerTransformerFSDPModule 的联合类。FSDPTransformerBlock 也是如此。所有 FSDP2 公共 API 都通过 FSDPModule 公开。例如,用户可以调用 model.unshard() 来手动控制 all-gather 调度。详情请参阅下文的“显式预取”。

model.parameters() 作为 DTensor: fully_shard 在各 rank 间分片参数,并将 model.parameters() 从普通的 torch.Tensor 转换为表示分片参数的 DTensor。FSDP2 默认在 dim-0 上分片,因此 DTensor 的放置策略为 Shard(dim=0)。假设有 N 个 rank,且一个参数在分片前有 N 行。分片后,每个 rank 将拥有该参数的 1 行。我们可以使用 param.to_local() 检查分片后的参数。

from torch.distributed.tensor import DTensor
for param in model.parameters():
    assert isinstance(param, DTensor)
    assert param.placements == (Shard(0),)
    # inspect sharded parameters with param.to_local()

optim = torch.optim.Adam(model.parameters(), lr=1e-2)

注意,优化器是在应用 fully_shard 之后构建的。模型和优化器的 state dict 都以 DTensor 表示。

DTensor 促进了优化器、梯度裁剪和检查点保存

  • torch.optim.Adamtorch.nn.utils.clip_grad_norm_ 可直接用于 DTensor 参数。这使得单设备训练和分布式训练之间的代码保持一致。

  • 我们可以使用 DTensor 和 DCP API 来操作参数以获取完整 state dict,详情请参阅下文的“state dict”章节。对于分布式 state dict,我们可以保存/加载检查点 (文档),而无需额外通信。

前向/后向与预取#

命令: torchrun --nproc_per_node 2 train.py

for _ in range(epochs):
    x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    loss = model(x).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

fully_shard 注册前向/后向钩子,以便在计算前 all-gather 参数,并在计算后重新分片。为了将 all-gather 与计算重叠,FSDP2 提供了隐式预取(开箱即用)和显式预取(供高级用户手动控制 all-gather 调度)。

隐式预取:CPU 线程在 layer i 之前发出 all-gather i。All-gather 被排入其自己的 cuda 流中,而 layer i 的计算在默认流中进行。对于非 CPU 密集型的工作负载(例如大 Batch Size 的 Transformer),all-gather i+1 可以与 layer i 的计算重叠。隐式预取在后向计算中工作方式类似,除了 all-gather 的发出顺序与前向相反。

FSDP Implicit

我们建议用户从隐式预取开始,以了解开箱即用的性能。

显式预取:用户可以使用 set_modules_to_forward_prefetch 指定前向顺序,使用 set_modules_to_backward_prefetch 指定后向顺序。如下面的代码所示,CPU 线程在 layer i 处发出 all-gather i + 1 和 i + 2。

显式预取在以下情况下效果良好:

CPU 密集型工作负载:如果使用隐式预取,当 layer i 的内核执行时,CPU 线程可能太慢而无法及时发出 layer i+1 的 all-gather。我们必须在运行 layer i 的前向计算之前显式发出 all-gather i+1。

2层及以上的预取:隐式预取一次只 all-gather 下一层,以保持最小内存占用。使用显式预取可以一次 all-gather 多个层,通过增加内存占用可能获得更好的性能。请参阅代码中的 layers_to_prefetch

更早发出第一次 all-gather:隐式预取发生在调用 model(x) 时,第一次 all-gather 会被暴露出来。我们可以更早显式调用 model.unshard() 来更早发出第一次 all-gather。

命令: torchrun --nproc_per_node 2 train.py --explicit-prefetching

num_to_forward_prefetch = 2
for i, layer in enumerate(model.layers):
    if i >= len(model.layers) - num_to_forward_prefetch:
        break
    layers_to_prefetch = [
        model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
    ]
    layer.set_modules_to_forward_prefetch(layers_to_prefetch)

num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
    if i < num_to_backward_prefetch:
        continue
    layers_to_prefetch = [
        model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
    ]
    layer.set_modules_to_backward_prefetch(layers_to_prefetch)

for _ in range(epochs):
    # trigger 1st all-gather earlier
    # this overlaps all-gather with any computation before model(x)
    model.unshard()
    x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    loss = model(x).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

启用混合精度#

FSDP2 提供了灵活的 混合精度策略 来加速训练。一个典型的用例是:

  • 将 float32 参数转换为 bfloat16 以进行前向/后向计算,请参阅 param_dtype=torch.bfloat16

  • 将梯度提升到 float32 以进行 reduce-scatter 以保持精度,请参阅 reduce_dtype=torch.float32

torch.amp 相比,FSDP2 混合精度具有以下优势:

  • 高效且灵活的参数转换FSDPModule 内的所有参数都在模块边界(前向/后向之前和之后)一起转换。我们可以为每一层设置不同的混合精度策略。例如,前几层可以是 float32,而其余层可以是 bfloat16。

  • float32 梯度归约 (reduce-scatter):不同 rank 之间的梯度可能差异很大。在 float32 中对梯度进行归约对于数值稳定性至关重要。

命令: torchrun --nproc_per_node 2 train.py --mixed-precision

model = Transformer(model_args)
fsdp_kwargs = {
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
}
for layer in model.layers:
    fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)

# sharded parameters are float32
for param in model.parameters():
    assert param.dtype == torch.float32

# unsharded parameters are bfloat16
model.unshard()
for param in model.parameters(recurse=False):
    assert param.dtype == torch.bfloat16
model.reshard()

# optimizer states are in float32
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# training loop
# ...

使用 DTensor 进行梯度裁剪和优化器操作#

命令: torchrun --nproc_per_node 2 train.py

# optim is constructed base on DTensor model parameters
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(epochs):
    x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    loss = model(x).sum()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optim.step()
    optim.zero_grad()

优化器在模型应用 fully_shard 后初始化,并持有对 DTensor model.parameters() 的引用。对于梯度裁剪,torch.nn.utils.clip_grad_norm_ 可用于 DTensor 参数。张量运算将在 DTensor 内部正确调度,以在各 rank 间通信部分张量,从而保持单设备语义。

使用 DTensor API 处理 State Dicts#

我们展示如何将完整 state dict 转换为 DTensor state dict 以进行加载,以及如何将其转换回完整 state dict 以进行保存。

命令: torchrun --nproc_per_node 2 train.py

  • 第一次时,为模型和优化器创建检查点

  • 第二次时,从之前的检查点加载以恢复训练

加载 state dicts:我们在 meta 设备下初始化模型并调用 fully_shard,将 model.parameters() 从普通 torch.Tensor 转换为 DTensor。读取 torch.load 的完整 state dict 后,我们可以调用 distribute_tensor,使用来自 model.state_dict() 的相同放置策略和设备网格,将普通 torch.Tensor 转换为 DTensor。最后,我们调用 model.load_state_dict 将 DTensor state dicts 加载到模型中。

from torch.distributed.tensor import distribute_tensor

# mmap=True reduces CPU memory usage
full_sd = torch.load(
    "checkpoints/model_state_dict.pt",
    mmap=True,
    weights_only=True,
    map_location='cpu',
)
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
    sharded_meta_param = meta_sharded_sd.get(param_name)
    sharded_tensor = distribute_tensor(
        full_tensor,
        sharded_meta_param.device_mesh,
        sharded_meta_param.placements,
    )
    sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# `assign=True` since we cannot call `copy_` on meta tensor
model.load_state_dict(sharded_sd, assign=True)

保存 state dictsmodel.state_dict() 返回一个 DTensor state dict。我们可以通过调用 full_tensor() 将 DTensor 转换为普通的 torch.Tensor。在内部,它会在各 rank 间发出 all-gather 以获取普通 torch.Tensor 中的未分片参数。对于 rank 0,full_param.cpu() 会逐个将张量卸载到 CPU,以避免未分片参数造成 GPU 内存峰值。

sharded_sd = model.state_dict()
cpu_state_dict = {}
for param_name, sharded_param in sharded_sd.items():
    full_param = sharded_param.full_tensor()
    if torch.distributed.get_rank() == 0:
        cpu_state_dict[param_name] = full_param.cpu()
    else:
        del full_param
torch.save(cpu_state_dict, "checkpoints/model_state_dict.pt")

优化器 state dict 的工作方式类似(代码)。用户可以自定义上述 DTensor 脚本以配合第三方检查点使用。

如果没有自定义需求,我们可以直接使用 DCP API 来支持单节点和多节点训练。

使用 DCP API 处理 State Dict#

命令: torchrun --nproc_per_node 2 train.py --dcp-api

  • 第一次时,为模型和优化器创建检查点

  • 第二次时,从之前的检查点加载以恢复训练

加载 state dicts:我们可以使用 set_model_state_dict 将完整 state dict 加载到 FSDP2 模型中。通过设置 broadcast_from_rank0=True,我们只需在 rank 0 上加载完整 state dict,从而避免 CPU 内存峰值。DCP 将对张量进行分片并将其广播到其他 rank。

from torch.distributed.checkpoint.state_dict import set_model_state_dict
set_model_state_dict(
    model=model,
    model_state_dict=full_sd,
    options=StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=True,
    ),
)

保存 state dicts:使用 full_state_dict=Truecpu_offload=Trueget_model_state_dict 会对张量进行 all-gather 并将其卸载到 CPU。其工作方式与 DTensor API 类似。

from torch.distributed.checkpoint.state_dict import get_model_state_dict
model_state_dict = get_model_state_dict(
    model=model,
    options=StateDictOptions(
        full_state_dict=True,
        cpu_offload=True,
    )
)
torch.save(model_state_dict, "model_state_dict.pt")

有关使用 set_optimizer_state_dictget_optimizer_state_dict 加载和保存优化器 state dict 的信息,请参阅 pytorch/examples

FSDP1 到 FSDP2 迁移指南#

让我们看一个 FSDP 用法示例及其等效的 fully_shard 用法。我们将重点介绍主要区别并提出迁移建议。

原始 FSDP() 用法

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with torch.device("meta"):
    model = Transformer()
policy = ModuleWrapPolicy({TransformerBlock})
model = FSDP(model, auto_wrap_policy=policy)
def param_init_fn(module: nn.Module) -> None: ...
model = FSDP(model, auto_wrap_policy=policy, param_init_fn=param_init_fn)

新的 fully_shard() 用法

with torch.device("meta"):
    model = Transformer()
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)
for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")


# Initialize the model after sharding
model.to_empty(device="cuda")
model.reset_parameters()

迁移步骤

  • 替换 import

  • 直接实现您的“策略”(将 fully_shard 应用于所需的子层)

  • fully_shard 包装根模型,而不是 FSDP

  • 去掉 param_init_fn 并手动调用 model.reset_parameters()

  • 替换其他 FSDP1 关键字参数(见下文)

sharding_strategy

  • FULL_SHARD: reshard_after_forward=True

  • SHARD_GRAD_OP: reshard_after_forward=False

  • HYBRID_SHARD: reshard_after_forward=True (使用二维设备网格)

  • _HYBRID_SHARD_ZERO2: reshard_after_forward=False (使用二维设备网格)

cpu_offload

  • CPUOffload.offload_params=False: offload_policy=None

  • CPUOffload.offload_params = True: offload_policy=CPUOffloadPolicy()

backward_prefetch

  • BACKWARD_PRE: 始终使用

  • BACKWARD_POST: 不支持

mixed_precision

  • buffer_dtype 被省略,因为 fully_shard 不会对缓冲区进行分片

  • fully_shard 的 cast_forward_inputs 对应于 FSDP1 中的 cast_forward_inputscast_root_forward_inputs

  • output_dtype 是 fully_shard 的新配置

device_id: 从 device_mesh 的设备中推断

sync_module_states=True/False: 已移至 DCP。用户可以使用带有 broadcast_from_rank0=Trueset_model_state_dict 从 rank0 广播 state dict

forward_prefetch: 可通过以下方式对手动预取进行控制

limit_all_gathers: 不再需要,因为 fully_shard 移除了 CPU 同步

use_orig_params: 始终使用原始参数(不再有扁平化参数)

no_sync(): set_requires_gradient_sync

ignored_params 和 ignored_states: ignored_params