使用 SPMD 的完全分片数据并行 (FSDP)¶
PyTorch/XLA 中的 FSDP 是一种将 Module 参数分片到数据并行工作器上的实用程序。
这与 PyTorch/XLA 中 FSDP 的其他实现不同之处在于,此实现使用了 SPMD。
在继续之前,请先查看 SPMD 用户指南 中的 SPMD 用户指南。您也可以在这里找到一个可运行的最小示例:here。
使用示例
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))
# Shard the input, and assume x is a 2D tensor.
x = xs.mark_sharding(x, mesh, ('fsdp', None))
# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
也可以单独分片各个层,并由一个外部包装器处理任何剩余的参数。以下是一个自动包装每个 DecoderLayer
的示例。
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
# Apply FSDP sharding on each DecoderLayer layer.
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
decoder_only_model.DecoderLayer
},
)
model = FSDPv2(
model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)
梯度检查点¶
目前,需要在 FSDP 包装器之前将梯度检查点应用于模块。否则,递归地进入子模块会导致无限循环。我们将在未来的版本中修复此问题。
使用示例
from torch_xla.distributed.fsdp import checkpoint_module
model = FSDPv2(checkpoint_module(my_module), mesh)