完全分片数据并行 (FSDP) 和每个加速器一个进程¶
PyTorch/XLA 中的 FSDP 是一种在数据并行工作节点之间分片 Module 参数的实用工具。
这与 PyTorch/XLA 中 FSDP 的其他实现不同之处在于,此实现每个加速器运行一个进程。
使用示例
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
也可以单独分片各个层,并由一个外部包装器处理任何剩余的参数。
注意:XlaFullyShardedDataParallel
类支持 https://arxiv.org/abs/1910.02054 中的 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态)。ZeRO-3 优化器应通过嵌套 FSDP 和 reshard_after_forward=True
来实现。请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.py
和 test/test_train_mp_imagenet_fsdp.py
作为示例。* 对于无法放入单个 TPU 内存或主机 CPU 内存的大型模型,应将子模块构建与内部 FSDP 包装交错进行。请参阅 FSDPViTModel 作为示例。提供了一个简单的包装器 checkpoint_module
(基于 https://github.com/pytorch/xla/pull/3524 中的 torch_xla.utils.checkpoint.checkpoint
),用于对给定的 nn.Module
实例执行梯度检查点。请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.py
和 test/test_train_mp_imagenet_fsdp.py
作为示例。自动包装子模块:除了手动嵌套 FSDP 包装外,还可以指定 auto_wrap_policy
参数,以自动使用内部 FSDP 包装子模块。torch_xla.distributed.fsdp.wrap
中的 size_based_auto_wrap_policy
是一个 auto_wrap_policy
可调用对象的示例,该策略包装参数数量大于 100M 的层。torch_xla.distributed.fsdp.wrap
中的 transformer_auto_wrap_policy
是针对类似 Transformer 的模型架构的 auto_wrap_policy
可调用对象的示例。
例如,要使用内部 FSDP 自动包装所有 torch.nn.Conv2d
子模块,可以使用
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
此外,还可以指定 auto_wrapper_callable
参数,为子模块使用自定义的可调用包装器(默认包装器只是 XlaFullyShardedDataParallel
类本身)。例如,可以使用以下方法对每个自动包装的子模块应用梯度检查点(即激活检查点/重构)。
from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
checkpoint_module(m), *args, **kwargs)
在步进优化器时,直接调用
optimizer.step
,不要调用xm.optimizer_step
。后者会在进程间缩减梯度,这对于 FSDP(其中参数已经分片)是不需要的。在训练过程中保存模型和优化器检查点时,每个训练进程都需要保存其自己的(分片)模型和优化器状态字典的检查点(在
xm.save
中使用master_only=False
并为每个进程设置不同的路径)。恢复时,需要加载对应进程的检查点。请同时保存
model.get_shard_metadata()
和model.state_dict()
,如下所示,并使用consolidate_sharded_model_checkpoints
将分片模型检查点缝合到完整的模型状态字典中。请参阅test/test_train_mp_mnist_fsdp_with_ckpt.py
作为示例。
ckpt = {
'model': model.state_dict(),
'shard_metadata': model.get_shard_metadata(),
'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth'
xm.save(ckpt, ckpt_path, master_only=False)
检查点合并脚本也可以从命令行启动,如下所示。
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /path/to/your_sharded_checkpoint_files \
--ckpt_suffix "_rank-*-of-*.pth"
此类实现的很大一部分灵感来自于 https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html 中的 fairscale.nn.FullyShardedDataParallel
,并且在很大程度上遵循其结构。与 fairscale.nn.FullyShardedDataParallel
最大的区别之一是,在 XLA 中我们没有显式的参数存储,因此我们在此采用了不同的方法来为 ZeRO-3 释放完整的参数。
MNIST 和 ImageNet 上的示例训练脚本¶
MNIST:test/test_train_mp_mnist_fsdp_with_ckpt.py(它还测试了检查点合并)
ImageNet:test/test_train_mp_imagenet_fsdp.py
安装¶
FSDP 在 PyTorch/XLA 1.12 及更高版本的 nightly 版本中可用。请参阅 https://github.com/pytorch/xla#-available-images-and-wheels 获取安装指南。
克隆 PyTorch/XLA 仓库¶
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST¶
2 个 epoch 可获得约 98.9 的准确率
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp --use_gradient_checkpointing
此脚本在最后自动测试检查点合并。您也可以手动合并分片检查点,通过
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet¶
100 个 epoch 可获得约 75.9 的准确率;将 ImageNet-1k 下载到 /datasets/imagenet-1k
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
--use_nested_fsdp
您还可以添加 --use_gradient_checkpointing
(需要与 --use_nested_fsdp
或 --auto_wrap_policy
一起使用)来对残差块应用梯度检查点。
TPU Pod(100 亿参数)上的示例训练脚本¶
要训练无法放入单个 TPU 的大型模型,在构建整个模型时应应用自动包装或手动包装子模块,以实现 ZeRO-3 算法。
请参阅 https://github.com/ronghanghu/vit_10b_fsdp_example,了解使用此 XLA FSDP PR 分片训练 Vision Transformer (ViT) 模型的示例。