使用完全分片数据并行 (FSDP) 进行高级模型训练#
创建于:2024年10月31日 | 最后更新:2024年10月31日 | 最后验证:2024年11月5日
作者: Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao
PyTorch 的完全分片数据并行模块:一个用于在数据并行工作进程之间分片模块参数的包装器。
数据并行工作进程。
PyTorch 1.12 或更高版本
阅读有关 FSDP API 的信息。
本教程作为 PyTorch 1.12 版本的一部分,介绍了完全分片数据并行 (FSDP) 的更高级功能。要熟悉 FSDP,请参阅 FSDP 入门教程。
在本教程中,我们将以文本摘要任务为例,使用 FSDP 微调 HuggingFace (HF) T5 模型。
该示例使用了 Wikihow 数据集,为简单起见,我们将展示在单节点、配备 8 个 A100 GPU 的 P4dn 实例上进行训练。我们现在有几篇博客文章( (链接1), (链接2))和一篇关于在多节点集群上进行大规模 FSDP 训练的 论文。
FSDP 是一个已投入生产的软件包,专注于易用性、性能和长期支持。FSDP 的主要好处之一是减少每个 GPU 上的内存占用。这使得能够使用比 DDP 更低的总体内存来训练更大的模型,并利用计算和通信的重叠来高效地训练模型。这种减少的内存压力可以用来训练更大的模型或增加批次大小,从而可能提高整体训练吞吐量。您可以在此处阅读更多关于 PyTorch FSDP 的信息 这里。
本教程中的 FSDP 功能#
Transformer 自动包装策略
混合精度
在设备上初始化 FSDP 模型
分片策略
后向预取
通过流式传输到 CPU 来保存模型检查点
FSDP 工作原理回顾#
高层次上,FSDP 工作如下:
在构造函数中
分片模型参数,每个 rank 只保留自己的分片
在前向传播中
运行 all_gather 从所有 rank 收集所有分片以恢复此 FSDP 单元的完整参数,并运行前向计算
丢弃刚收集的非所有权参数分片以释放内存
在后向传播中
运行 all_gather 从所有 rank 收集所有分片以恢复此 FSDP 单元的完整参数,并运行后向计算
丢弃非所有权参数以释放内存。
运行 reduce_scatter 以同步梯度
微调 HF T5#
HF T5 预训练模型有四种不同的尺寸,从 6000 万参数的小型模型到 110 亿参数的 XXL 模型。在本教程中,我们演示了使用 FSDP 微调 T5 3B 模型以进行文本摘要,使用 WikiHow 数据集。本教程的重点是展示 FSDP 中有助于训练 3B 参数以上的大规模模型的各种可用功能。此外,我们还将介绍专门针对 Transformer 模型的功能。本教程的代码可在 Pytorch 示例 中找到。
设置
1.1 安装最新的 PyTorch
pip3 install torch torchvision torchaudio
1.2 数据集设置
请创建一个 data 文件夹,从 wikihowAll.csv 和 wikihowSep.cs 下载 WikiHow 数据集,并将其放入 data 文件夹。我们将使用来自 summarization_dataset 的 wikihow 数据集。
接下来,我们将以下代码片段添加到名为“T5_training.py”的 Python 脚本中。
注意
本教程的完整源代码可在 PyTorch 示例 中找到。
1.3 导入必要的包
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing_wrapper)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime
1.4 分布式训练设置。这里我们使用两个辅助函数来初始化分布式训练的进程,然后在训练完成后进行清理。在本教程中,我们将使用 torch elastic,通过 torchrun,它将自动设置 worker 的 RANK 和 WORLD_SIZE。
def setup():
# initialize the process group
dist.init_process_group("nccl")
def cleanup():
dist.destroy_process_group()
2.1 设置 HuggingFace T5 模型
def setup_model(model_name):
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
return model, tokenizer
我们还在这里添加了几个用于日期和格式化内存指标的辅助函数。
def get_date_of_run():
"""create date and time for file save uniqueness
example: 2022-05-07-08:31:12_PM'
"""
date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
print(f"--> current date and time of run = {date_of_run}")
return date_of_run
def format_metrics_to_gb(item):
"""quick function to format numbers to gigabyte and round to 4 digit precision"""
metric_num = item / g_gigabyte
metric_num = round(metric_num, ndigits=4)
return metric_num
2.2 定义训练函数
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
model.train()
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(2).to(local_rank)
if sampler:
sampler.set_epoch(epoch)
if rank==0:
inner_pbar = tqdm.tqdm(
range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
)
for batch in train_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
optimizer.zero_grad()
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
loss = output["loss"]
loss.backward()
optimizer.step()
fsdp_loss[0] += loss.item()
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
train_accuracy = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(
f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
)
return train_accuracy
2.3 定义验证函数
def validation(model, rank, world_size, val_loader):
model.eval()
correct = 0
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(3).to(local_rank)
if rank == 0:
inner_pbar = tqdm.tqdm(
range(len(val_loader)), colour="green", desc="Validation Epoch"
)
with torch.no_grad():
for batch in val_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
fsdp_loss[0] += output["loss"].item() # sum up batch loss
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
val_loss = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(f"Validation Loss: {val_loss:.4f}")
return val_loss
2.4 定义包装模型在 FSDP 中的分布式训练函数
def fsdp_main(args):
model, tokenizer = setup_model("t5-base")
local_rank = int(os.environ['LOCAL_RANK'])
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dataset = load_dataset('wikihow', 'all', data_dir='data/')
print(dataset.keys())
print("Size of train dataset: ", dataset['train'].shape)
print("Size of Validation dataset: ", dataset['validation'].shape)
#wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
setup()
train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2,
'pin_memory': True,
'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
torch.cuda.set_device(local_rank)
#init_start_event = torch.cuda.Event(enable_timing=True)
#init_end_event = torch.cuda.Event(enable_timing=True)
#init_start_event.record()
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
if bf16_ready:
mp_policy = bfSixteen
else:
mp_policy = None # defaults to fp32
# model is on CPU before input to FSDP
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=mp_policy,
#sharding_strategy=sharding_strategy,
device_id=torch.cuda.current_device())
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
best_val_loss = float("inf")
curr_val_loss = float("inf")
file_save_name = "T5-model-"
if rank == 0:
time_of_run = get_date_of_run()
dur = []
train_acc_tracking = []
val_acc_tracking = []
training_start_time = time.time()
if rank == 0 and args.track_memory:
mem_alloc_tracker = []
mem_reserved_tracker = []
for epoch in range(1, args.epochs + 1):
t0 = time.time()
train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
if args.run_validation:
curr_val_loss = validation(model, rank, world_size, val_loader)
scheduler.step()
if rank == 0:
print(f"--> epoch {epoch} completed...entering save and stats zone")
dur.append(time.time() - t0)
train_acc_tracking.append(train_accuracy.item())
if args.run_validation:
val_acc_tracking.append(curr_val_loss.item())
if args.track_memory:
mem_alloc_tracker.append(
format_metrics_to_gb(torch.cuda.memory_allocated())
)
mem_reserved_tracker.append(
format_metrics_to_gb(torch.cuda.memory_reserved())
)
print(f"completed save and stats zone...")
if args.save_model and curr_val_loss < best_val_loss:
# save
if rank == 0:
print(f"--> entering save model state")
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
#print(f"saving process: rank {rank} done w state_dict")
if rank == 0:
print(f"--> saving model ...")
currEpoch = (
"-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
)
print(f"--> attempting to save model prefix {currEpoch}")
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
print(f"--> saving as model name {save_name}")
torch.save(cpu_state, save_name)
if curr_val_loss < best_val_loss:
best_val_loss = curr_val_loss
if rank==0:
print(f"-->>>> New Val Loss Record: {best_val_loss}")
dist.barrier()
cleanup()
2.5 解析参数并设置主函数
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
parser.add_argument('--batch-size', type=int, default=4, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
help='number of epochs to train (default: 3)')
parser.add_argument('--lr', type=float, default=.002, metavar='LR',
help='learning rate (default: .002)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--track_memory', action='store_false', default=True,
help='track the gpu memory')
parser.add_argument('--run_validation', action='store_false', default=True,
help='running the validation')
parser.add_argument('--save-model', action='store_false', default=True,
help='For Saving the current Model')
args = parser.parse_args()
torch.manual_seed(args.seed)
fsdp_main(args)
使用 torchrun 运行训练
torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
Transformer 包装策略#
正如在 上一教程 中讨论的,auto_wrap_policy 是 FSDP 的功能之一,可以轻松地自动分片给定的模型,并将模型、优化器和梯度分片放入不同的 FSDP 单元。
对于像 Transformer 编码器-解码器这样的某些架构,模型的某些部分(例如嵌入表)被编码器和解码器共享。在这种情况下,我们需要将嵌入表放置在外部 FSDP 单元中,以便它可以从编码器和解码器进行访问。此外,通过注册 Transformer 的层类,分片计划可以变得更具通信效率。在 PyTorch 1.12 中,FSDP 添加了此支持,现在我们有了 Transformer 的包装策略。
可以按如下方式创建它,其中 T5Block 代表 T5 Transformer 层类(包含 MHSA 和 FFN)。
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy)
要查看包装后的模型,您可以轻松打印模型并直观地检查分片和 FSDP 单元。
混合精度#
FSDP 支持灵活的混合精度训练,允许任意降低的精度类型(如 fp16 或 bfloat16)。目前 BFloat16 仅在 Ampere GPU 上可用,因此在使用它之前需要确认原生支持。例如,在 V100 上,BFloat16 仍然可以运行,但由于它不是原生运行,可能会导致显着减慢。
要检查 BFloat16 是否受到原生支持,您可以使用以下方法:
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
FSDP 中混合精度的优势之一是能够对参数、梯度和缓冲区提供不同精度级别的精细控制,如下所示:
fpSixteen = MixedPrecision(
param_dtype=torch.float16,
# Gradient communication precision.
reduce_dtype=torch.float16,
# Buffer precision.
buffer_dtype=torch.float16,
)
bfSixteen = MixedPrecision(
param_dtype=torch.bfloat16,
# Gradient communication precision.
reduce_dtype=torch.bfloat16,
# Buffer precision.
buffer_dtype=torch.bfloat16,
)
fp32_policy = MixedPrecision(
param_dtype=torch.float32,
# Gradient communication precision.
reduce_dtype=torch.float32,
# Buffer precision.
buffer_dtype=torch.float32,
)
请注意,如果未指定某些类型(参数、reduce、缓冲区),则它们将不会被转换。
这种灵活性允许用户进行精细控制,例如,仅在降低的精度下进行梯度通信,而所有参数/缓冲区计算都以全精度进行。这在节点内通信是主要瓶颈且参数/缓冲区必须以全精度避免准确性问题的情况下可能很有用。这可以通过以下策略完成:
grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
在 2.4 中,我们将相关的混合精度策略添加到 FSDP 包装器中。
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen)
在我们的实验中,我们观察到使用 BFloat16 训练可将速度提高高达 4 倍,并且在一些用于增加批次大小的实验中内存减少约 30%。
在设备上初始化 FSDP 模型#
在 1.12 中,FSDP 支持一个 device_id 参数,用于在给定的 device_id 指定的设备上初始化输入 CPU 模块。当整个模型不适合单个 GPU 但适合主机 CPU 内存时,这非常有用。当指定 device_id 时,FSDP 将在每个 FSDP 单元的基础上将模型移动到指定的设备,从而避免 GPU OOM 问题,同时初始化速度比基于 CPU 的初始化快几倍。
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device())
后向预取#
后向预取设置控制下一个 FSDP 单元参数请求的时间。通过将其设置为 BACKWARD_PRE,可以在当前单元计算开始之前,请求并开始接收下一个 FSDP 单元的参数。这会重叠 all_gather 通信和梯度计算,从而可能提高训练速度,但会稍微增加内存消耗。可以在 2.4 节的 FSDP 包装器中使用它,如下所示:
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device(),
backward_prefetch = BackwardPrefetch.BACKWARD_PRE)
backward_prefetch 有两种模式:BACKWARD_PRE 和 BACKWARD_POST。BACKWARD_POST 意味着下一个 FSDP 单元的参数将在当前 FSDP 单元处理完成之前不会被请求,从而最大限度地减少内存开销。在某些情况下,使用 BACKWARD_PRE 可以将模型训练速度提高 2-10%,对于较大的模型甚至可以提高更高的速度。
通过流式传输到 Rank0 CPU 进行模型检查点保存#
要使用 FULL_STATE_DICT 保存模型检查点,该检查点以与本地模型相同的方式保存模型,PyTorch 1.12 提供了一些实用工具来支持保存更大的模型。
首先,可以指定 FullStateDictConfig,允许 state_dict 仅在 rank 0 上填充并卸载到 CPU。
当使用此配置时,FSDP 将逐个 allgather 模型参数,并将它们卸载到 CPU,仅在 rank 0 上进行。当 state_dict 最终保存时,它将仅在 rank 0 上填充并包含 CPU 张量。这避免了大于单个 GPU 内存的模型潜在的 OOM 问题,并允许用户检查模型大小大约等于用户机器可用 CPU RAM 的模型。
可以使用以下方式运行此功能:
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
if rank == 0:
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
torch.save(cpu_state, save_name)
摘要#
在本教程中,我们介绍了 PyTorch 1.12 中 FSDP 的许多新功能,并使用 HF T5 作为运行示例。使用正确的包装策略,特别是对于 Transformer 模型,以及混合精度和后向预取,应该可以加快您的训练运行。此外,诸如在设备上初始化模型以及通过流式传输到 CPU 来保存检查点等功能,应该有助于避免处理大型模型的 OOM 错误。
我们正在积极致力于为下一个版本为 FSDP 添加新功能。如果您有反馈、功能请求、问题或在使用 FSDP 时遇到问题,请随时通过在 PyTorch GitHub 存储库 中打开 issue 来与我们联系。