上下文并行简介#
作者: Xilun Wu, Chien-Chin Huang
注意
在 GitHub 上查看和编辑本教程。
PyTorch 2.7 或更高版本
简介#
上下文并行是一种用于大型语言模型训练的方法,通过将长输入序列分片到多个设备上来减少峰值激活大小。它打破了 Transformer 块中存储激活的峰值内存使用量对输入序列长度的限制。
Ring Attention 是一种新颖的 Attention 层并行实现,对于高性能的上下文并行至关重要。Ring Attention 会混洗 KV 片段并计算部分注意力分数,重复此过程直到每个设备都使用了所有 KV 片段。已实现了两种 Ring Attention 变体:基于 all-gather 的 pass-KV 和 基于 all-to-all 的 pass-KV
基于 all-gather 的 pass-KV 算法用于 Llama3 训练,该算法首先对 key 和 value 张量执行 all-gather,然后计算局部 query 张量块的注意力输出。我们修改后的基于 all-gather 的 pass-KV 算法同时执行 KV 片段的 all-gather,并使用局部的 key 和 value 张量块计算局部 query 张量块的注意力输出,最后计算局部 query 张量和剩余 KV 片段的注意力输出。这允许注意力计算和 all-gather 集合操作之间存在一定程度的重叠。例如,在 Llama3 训练中,我们还将
freq_cis沿着序列维度进行分片。基于 all-to-all 的方法使用交错的 all-to-all 集合操作来环式混洗 KV 片段,以重叠 SDPA (Scaled Dot Product Attention) 计算和下一个 SDPA 所需的 all-to-all 通信。
上下文并行 API 包含两个部分
context_parallel()允许用户创建一个 Python 上下文,在该上下文中,SDPA 函数(torch.nn.functional.scaled_dot_product_attention)将被自动替换为 Ring Attention。要沿着某个维度分片张量,只需将张量及其分片维度分别传递给参数buffers和buffer_seq_dims。我们建议用户将沿序列维度计算的张量添加到buffers中,并沿着该维度对其进行分片。以 Llama3 训练为例,如果buffers中缺少freq_cis,将导致旋转嵌入计算错误。set_rotate_method()允许用户在基于 all-gather 的 pass-KV 方法和基于 all-to-all 的 pass-KV 方法之间进行选择。
设置#
使用 torch.distributed.tensor.experimental.context_parallel(),用户可以轻松地对张量输入进行分片并并行化 SDPA 函数的执行。为了更好地演示此 API 的用法,我们将从一个简单的执行 SDPA 的代码片段开始,然后使用该 API 对其进行并行化。
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
def sdpa_example():
assert torch.cuda.is_available()
torch.cuda.set_device("cuda:0")
torch.cuda.manual_seed(0)
batch = 8
nheads = 8
qkv_len = 8192
dim = 32
backend = SDPBackend.FLASH_ATTENTION
dtype = (
torch.bfloat16
if backend == SDPBackend.FLASH_ATTENTION
or backend == SDPBackend.CUDNN_ATTENTION
else torch.float32
)
qkv = [
torch.rand(
(batch, nheads, qkv_len, dim),
dtype=dtype,
requires_grad=True,
device='cuda',
)
for _ in range(3)
]
# specify the SDPBackend to use
with sdpa_kernel(backend):
out = F.scaled_dot_product_attention(*qkv, is_causal=True)
if __name__ == "__main__":
sdpa_example()
启用上下文并行#
现在,让我们首先将其改编为一个分布式程序,其中每个进程(rank)都具有相同的张量输入。然后,我们将应用上下文并行 API 对输入进行分片,并将计算分发到各个进程。
# file: cp_sdpa_example.py
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard
from torch.nn.attention import sdpa_kernel, SDPBackend
def context_parallel_sdpa_example(world_size: int, rank: int):
assert torch.cuda.is_available()
assert dist.is_nccl_available()
torch.cuda.set_device(f"cuda:{rank}")
torch.cuda.manual_seed(0)
dist.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)
device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
)
batch = 8
nheads = 8
qkv_len = 64
dim = 32
backend = SDPBackend.FLASH_ATTENTION
dtype = (
torch.bfloat16
if backend == SDPBackend.FLASH_ATTENTION
or backend == SDPBackend.CUDNN_ATTENTION
else torch.float32
)
qkv = [
torch.rand(
(batch, nheads, qkv_len, dim),
dtype=dtype,
requires_grad=True,
device='cuda',
)
for _ in range(3)
]
# specify the SDPBackend to use
with sdpa_kernel(backend):
out = F.scaled_dot_product_attention(*qkv, is_causal=True)
# make a clean copy of QKV for output comparison
cp_qkv = [t.detach().clone() for t in qkv]
with sdpa_kernel(backend):
# This `context_parallel()` performs two actions:
# 1. Shard the tensor objects in `buffers` in-place along the dimension
# specified in `buffer_seq_dims`, the tensors in `buffers` and their
# sharding dims in `buffer_seq_dims` are organized in the same order.
# 2. Replace the execution of `F.scaled_dot_product_attention` with a
# context-paralleled-enabled Ring Attention.
with context_parallel(
device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
):
cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)
# The output `cp_out` is still sharded in the same way as QKV
# the `context_parallel_unshard` API allows users to easily
# unshard to gain the full tensor.
(cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])
assert torch.allclose(
cp_out,
out,
atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
)
if __name__ == "__main__":
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
try:
context_parallel_sdpa_example(world_size, rank)
finally:
dist.barrier()
dist.destroy_process_group()
您可以使用命令 torchrun --standalone --nnodes=1 --nproc-per-node=4 cp_sdpa_example.py 在 4 个 GPU 上启动上述上下文并行 SDPA。我们通过将 Ring Attention 的输出与单个 GPU 上的 SDPA 的输出进行比较来展示数值正确性。
选择旋转方法#
您可以使用 torch.distributed.tensor.experimental._attention.set_rotate_method() 在 Ring Attention 中选择所需的片段旋转方法。
# file: cp_sdpa_example.py
from torch.distributed.tensor.experimental._attention import set_rotate_method
set_rotate_method("alltoall") # rotate shards using all-to-all
with sdpa_kernel(backend):
with context_parallel(
device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
):
cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)
默认的旋转方法是基于 all-gather 的 pass-KV。
结论#
在本教程中,我们学习了如何使用我们的上下文并行 API 轻松地沿序列维度并行化 SDPA 计算。有关设计和实现细节、性能分析以及 TorchTitan 中的端到端训练示例,请参阅我们在 PyTorch 原生长上下文训练 上的帖子。