• 文档 >
  • PyTorch XLA 中的 Fully Sharded Data Parallel
快捷方式

PyTorch XLA 中的 Fully Sharded Data Parallel

PyTorch XLA 中的 Fully Sharded Data Parallel (FSDP) 是一种在数据并行工作器之间分片 Module 参数的实用工具。

使用示例

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以单独分片各个层,并让外部包装器处理任何剩余的参数。

注意:XlaFullyShardedDataParallel 类支持 ZeRO-2 优化器(分片梯度和优化器状态)以及 ZeRO-3 优化器(分片参数、梯度和优化器状态),参考 https://arxiv.org/abs/1910.02054。ZeRO-3 优化器应通过嵌套 FSDP 实现,并设置 reshard_after_forward=True。示例请参见 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py。 * 对于无法放入单个 TPU 内存或主机 CPU 内存的大型模型,应在子模块构造时交错内部 FSDP 包装。示例请参见 FSDPViTModel。提供了一个简单的包装器 checkpoint_module(基于 https://github.com/pytorch/xla/pull/3524 中的 torch_xla.utils.checkpoint.checkpoint)来对给定的 nn.Module 实例执行 梯度检查点。示例请参见 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py。自动包装子模块:除了手动嵌套 FSDP 包装外,还可以指定 auto_wrap_policy 参数来自动使用内部 FSDP 包装子模块。 torch_xla.distributed.fsdp.wrap 中的 size_based_auto_wrap_policy 是一个 auto_wrap_policy 可调用对象的示例,该策略包装参数量大于 1 亿的层。 torch_xla.distributed.fsdp.wrap 中的 transformer_auto_wrap_policy 是一个针对类似 Transformer 的模型架构的 auto_wrap_policy 可调用对象的示例。

例如,要自动包装所有 torch.nn.Conv2d 子模块并使用内部 FSDP,可以使用

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

此外,还可以指定 auto_wrapper_callable 参数来为子模块使用自定义的可调用包装器(默认包装器就是 XlaFullyShardedDataParallel 类本身)。例如,可以使用以下方法对每个自动包装的子模块应用梯度检查点(即激活检查点/重计算)。

from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
  • 在步进优化器时,直接调用 optimizer.step,不要调用 xm.optimizer_step。后者会跨进程规约梯度,这对于 FSDP(参数已分片)是不需要的。

  • 在训练期间保存模型和优化器检查点时,每个训练进程需要保存其自己的(分片)模型和优化器状态字典的检查点(使用 master_only=False,并在 xm.save 中为每个进程设置不同的路径)。恢复时,需要加载相应进程的检查点。

  • 请同时保存 model.get_shard_metadata()model.state_dict(),如下所示,并使用 consolidate_sharded_model_checkpoints 将分片模型检查点拼接成完整的模型状态字典。示例请参见 test/test_train_mp_mnist_fsdp_with_ckpt.py

ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
    'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth'
xm.save(ckpt, ckpt_path, master_only=False)
  • 也可以从命令行启动检查点合并脚本,如下所示。

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /path/to/your_sharded_checkpoint_files \
  --ckpt_suffix "_rank-*-of-*.pth"

该类的实现很大程度上受到 https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.htmlfairscale.nn.FullyShardedDataParallel 的启发,并主要遵循其结构。与 fairscale.nn.FullyShardedDataParallel 相比,最大的区别之一是,在 XLA 中我们没有显式的参数存储,因此这里我们采取了不同的方法来为 ZeRO-3 释放完整参数。

MNIST 和 ImageNet 的示例训练脚本

安装

FSDP 在 PyTorch/XLA 1.12 版本及更新的 nightly 版本中可用。有关安装指南,请参考 https://github.com/pytorch/xla#-available-images-and-wheels

克隆 PyTorch/XLA 仓库

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在 v3-8 TPU 上训练 MNIST

2 个 epoch 约为 98.9 的准确率

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

该脚本在最后会自动测试检查点合并。您也可以通过以下方式手动合并分片检查点:

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet

100 个 epoch 约为 75.9 的准确率;将 ImageNet-1k 下载到 /datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

您还可以添加 --use_gradient_checkpointing(需要与 --use_nested_fsdp--auto_wrap_policy 一起使用)来对残差块应用梯度检查点。

TPU Pod(具有 100 亿参数)上的示例训练脚本

要训练无法放入单个 TPU 的大型模型,在构建整个模型时应应用自动包装或手动包装子模块与内部 FSDP,以实现 ZeRO-3 算法。

请参阅 https://github.com/ronghanghu/vit_10b_fsdp_example,了解使用此 XLA FSDP PR 训练 Vision Transformer (ViT) 模型的示例。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源