使用张量并行 (TP) 进行大规模 Transformer 模型训练#
创建日期:2024年4月19日 | 最后更新:2025年7月18日 | 最后验证:2024年11月5日
作者: Wanchao Liang, Tianyu Liu
注意
在 github 上查看和编辑本教程。
本教程演示了如何使用张量并行和完全分片数据并行(Fully Sharded Data Parallel)在数百到数千个 GPU 上训练类似 Transformer 的大型模型。
先决条件
安装了 CUDA/Linux 的 PyTorch 2.3.0 或更高版本
张量并行是如何工作的?#
张量并行 (TP) 最初在 Megatron-LM 论文中提出,它是训练大规模 Transformer 模型的一种高效模型并行技术。本教程中提到的序列并行 (SP) 是张量并行的一种变体,它在 nn.LayerNorm 或 RMSNorm 的序列维度上进行分片,以在训练期间进一步节省激活内存。随着模型规模的增大,激活内存成为瓶颈,因此在张量并行训练中,通常会对 LayerNorm 或 RMSNorm 层应用序列并行。
图 1 展示了 Transformer 模型 MLP 和自注意力层中张量并行风格的分片,其中注意力/MLP 中的矩阵乘法通过分片计算完成 (图像来源)#
总体而言,PyTorch 张量并行的工作方式如下:
分片初始化
确定应用于每一层的
ParallelStyle,并通过调用parallelize_module对初始化的模块进行分片。并行化后的模块将其模型参数交换为 DTensor,DTensor 将负责使用分片计算运行并行化后的模块。
运行时前向/后向传播
根据用户为每个
ParallelStyle指定的输入/输出 DTensor 布局,它将运行适当的通信操作来转换输入/输出的 DTensor 布局(例如allreduce、allgather和reduce_scatter)。为并行化层运行分片计算以节省计算/内存(例如
nn.Linear,nn.Embedding)。
何时以及为什么应该应用张量并行#
PyTorch 完全分片数据并行 (FSDP) 已经具备将模型训练扩展到特定数量 GPU 的能力。然而,当需要在模型规模和 GPU 数量方面进一步扩展模型训练时,会出现许多额外挑战,可能需要将张量并行与 FSDP 结合使用。
随着世界大小(GPU 数量)变得过大(超过 128/256 个 GPU),FSDP 集合通信(如
allgather)会受到环形延迟的主导。通过在 FSDP 之上实现 TP/SP,可以将 FSDP 的世界大小减少 8 倍(即仅在主机间应用 FSDP),从而等量减少延迟成本。当达到数据并行限制(即由于收敛性和 GPU 内存限制,无法将全局批量大小提高到 GPU 数量以上)时,张量/序列并行是唯一已知的可以“估算”全局批量大小并随着更多 GPU 继续扩展的方法。这意味着模型大小和 GPU 数量都可以继续扩展。
对于某些类型的模型,当局部批量大小变小时,TP/SP 可以产生针对浮点运算 (FLOPS) 更优化的矩阵乘法形状。
那么,在预训练时,达到这些限制有多容易?目前,使用数千个 GPU 预训练具有数十亿或数万亿个 Token 的大型语言模型 (LLM) 可能需要数月时间。
在大规模训练 LLM 时,总是会遇到第 1 个限制。例如,Llama 2 70B 在 2000 个 GPU 上训练了 35 天,在 2k 规模下需要多维并行。
当 Transformer 模型变大时(如 Llama 2 70B),也会很快遇到第 2 个限制。由于内存和收敛性限制,即使局部
batch_size=1,也无法单独使用 FSDP。例如,Llama 2 的全局批量大小为 1K,因此在 2K 个 GPU 上无法仅使用数据并行。
如何应用张量并行#
PyTorch 张量并行 API 提供了一组模块级原语 (ParallelStyle) 来配置模型每一层的分片,包括:
ColwiseParallel和RowwiseParallel:以列或行方式对nn.Linear和nn.Embedding进行分片。SequenceParallel:在nn.LayerNorm,nn.Dropout,RMSNormPython等上执行分片计算。PrepareModuleInput和PrepareModuleOutput:配置具有适当通信操作的模块输入/输出分片布局。
为了演示如何使用 PyTorch 原生张量并行 API,让我们看一个常见的 Transformer 模型。在本教程中,我们使用最新的 Llama 2 模型 作为参考 Transformer 模型实现,因为它在社区中也被广泛使用。
由于张量并行在一组设备上对单个张量进行分片,我们需要首先设置分布式环境(例如 NCCL 通信器)。张量并行是一种类似于 PyTorch DDP/FSDP 的单程序多数据 (SPMD) 分片算法,它在底层利用 PyTorch DTensor 执行分片。它还利用 DeviceMesh 抽象(在底层管理 ProcessGroups)进行设备管理和分片。要了解如何利用 DeviceMesh 设置多维并行,请参阅 此教程。张量并行通常在每个主机内工作,所以让我们首先初始化一个连接主机内 8 个 GPU 的 DeviceMesh。
from torch.distributed.device_mesh import init_device_mesh
tp_mesh = init_device_mesh("cuda", (8,))
现在我们已经初始化了 DeviceMesh,让我们详细了解一下 Llama 2 模型架构,看看我们应该如何执行张量并行分片。在这里,我们专注于核心 TransformerBlock,其中 Transformer 模型堆叠相同的 TransformerBlock 来扩展模型。
核心 TransformerBlock 由一个 Attention 层和一个 FeedForward 层组成。让我们先看更简单的 FeedForward 层。对于 FeedForward 层,它由三个线性层组成,执行 SwiGLU 风格的 MLP,查看其前向函数:
# forward in the FeedForward layer
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
它同时执行 w1 和 w3 矩阵乘法,然后用 w1/w3 线性投影结果的组合进行 w2 矩阵乘法。这意味着我们可以使用张量并行论文中的思想,以列方式对 w1/w3 线性层进行分片,并以行方式对 w2 线性层进行分片,这样在所有三层结束时只有一个 allreduce 通信发生。使用 PyTorch 原生张量并行,我们可以简单地为 FeedForward 层创建一个 parallelize_plan,如下所示:
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"feed_foward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
这就是我们使用 PyTorch 张量并行 API 为 FeedForward 层配置分片的方式。注意,用户只需要指定如何对各个层进行分片,通信(例如 allreduce)会在底层自动发生。
继续看 Attention 层。它由 wq, wk, wv 线性层组成,用于将输入投影到 q/k/v,然后它使用 wo 线性层执行注意力计算和输出投影。这里的张量并行旨在对 q/k/v 投影执行列式分片,对 wo 线性投影执行行式分片。因此,我们可以将 Attention 计划添加到我们刚刚起草的 tp_plan 中。
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"attention.wq": ColwiseParallel(use_local_output=False),
"attention.wk": ColwiseParallel(use_local_output=False),
"attention.wv": ColwiseParallel(use_local_output=False),
"attention.wo": RowwiseParallel(),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
这几乎就是我们需要对 TransformerBlock 应用张量并行所需的 layer_tp_plan。但是,有一点需要注意:当按列对线性层进行分片时,线性层的输出将在最后一个张量维度上被分片,而按行分片的线性层直接接受在最后一个维度上分片的输入。如果在列式线性层和行式线性层之间有任何其他的张量操作(例如 view 操作),我们需要将相关的形状相关操作调整为分片形状。
对于 Llama 模型,在注意力层中,有几个与形状相关的 view 操作。具体来说,对于 wq/wk/wv 线性层中的列式并行,激活张量在 num_heads 维度上被分片。为了管理全局和局部 num_heads 之间的差异,我们应该设置 use_local_output=False 以确保输出是一个 DTensor。与普通张量不同,DTensor 意识到并行计划,并将自动处理 num_heads 维度的变化。
最后,我们需要调用 parallelize_module API 来使每个 TransformerBlock 的计划生效。在底层,它将 Attention 和 FeedForward 层内的模型参数分发给 DTensors,并根据需要为模型输入和输出注册通信钩子(分别在每个模块之前和之后)。
for layer_id, transformer_block in enumerate(model.layers):
layer_tp_plan = {...} # i.e. the plan we just generated
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_tp_plan,
)
现在我们已经阐述了每个 TransformerBlock 的分片计划,第一层通常有一个 nn.Embedding,最后一层有一个最终的 nn.Linear 投影层,用户可以选择对第一个 nn.Embedding 进行行式或列式分片,并对最后一个 nn.Linear 投影层进行列式分片,并指定适当的输入和输出布局。这是一个示例:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
"output": ColwiseParallel(
output_layouts=Replicate(),
),
}
)
注意
如果要划分的模型太大而无法放入 CPU 内存,可以使用 meta 设备初始化(例如,首先在 meta 设备上初始化模型,对层进行分片,然后将模型具体化),或者在 Transformer 模型初始化期间逐层并行化 TransformerBlock。
将序列并行应用于 LayerNorm/RMSNorm 层#
序列并行工作在上述张量并行之上。与仅在 Attention 模块和 FeedForward 模块内分片张量并保持其模块输入和输出(即前向传播中的激活和反向传播中的梯度)复制的基本张量并行相比,序列并行将它们保持在序列维度上进行分片。
在一个典型的 TransformerBlock 中,前向函数结合了归一化层(LayerNorm 或 RMSNorm)、注意力层、前馈层和残差连接。例如:
# forward in a TransformerBlock
def forward(self, x):
h = x + self.attention(self.attention_norm(x))
out = h + self.feed_forward(self.ffn_norm(h))
return out
在大多数用例中,在 Attention 和 FeedForward 模块之外,激活(和梯度)的形状为 [batch size, sequence length, hidden dimension]。用 DTensor 的语言来说,序列并行在模块的前向/反向传播中都使用 Shard(1) 布局进行激活计算。按照前面的代码示例,下面的代码演示了我们如何将序列并行应用于 TransformerBlock 内的归一化层:
首先,让我们导入序列并行所需的依赖项。
from torch.distributed.tensor.parallel import (
PrepareModuleInput,
SequenceParallel,
)
接下来,让我们调整 layer_tp_plan,以在 RMSNorm 层上启用序列并行。
layer_tp_plan = {
# Now the input and output of SequenceParallel has Shard(1) layouts,
# to represent the input/output tensors sharded on the sequence dimension
"attention_norm": SequenceParallel(),
"attention": PrepareModuleInput(
input_layouts=(Shard(1), Replicate()),
desired_input_layouts=(Replicate(), Replicate()),
),
"attention.wq": ColwiseParallel(use_local_output=False),
"attention.wk": ColwiseParallel(use_local_output=False),
"attention.wv": ColwiseParallel(use_local_output=False),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w3": ColwiseParallel(),
}
可以看到,我们现在使用 PrepareModuleInput 将 Attention 和 FeedForward 层的模块输入布局从 Shard(1) 修改为 Replicate(),并将它们的输出布局标记为 Shard(1)。就像张量并行一样,只需要指定输入和输出的张量分片布局,层与层之间的通信就会自动发生。
注意,使用序列并行时,我们假设 TransformerBlock 的输入和输出始终在序列维度上分片,这样多个 TransformerBlock 就可以无缝连接。这可以通过明确指定起始 nn.Embedding 层的输出和最终 nn.Linear 投影层的输入为 Shard(1) 来实现。
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate()
),
}
)
应用损失并行#
损失并行是一种相关技术,用于在计算损失函数时节省内存和通信,因为模型输出通常非常大。在损失并行中,当模型输出在(通常很大)的词汇维度上进行分片时,可以高效地计算交叉熵损失,而无需将所有模型输出收集到每个单独的 GPU 上。这不仅显着降低了内存消耗,而且通过减少通信开销和并行进行分片计算,提高了训练速度。下图简要说明了损失并行如何通过进行分片计算来避免将所有模型输出收集到每个 GPU。
图 2. 在一个 GPU 上进行损失并行的交叉熵损失前向计算。蓝色表示分片张量;绿色表示复制的张量;黄色表示具有部分值的张量(将进行 all-reduce)。黑色箭头是局部计算;红色箭头是 GPU 之间的功能性集合通信。#
在 PyTorch 张量并行 API 中,可以通过上下文管理器 loss_parallel 启用损失并行,通过它可以直接使用 torch.nn.functional.cross_entropy 或 torch.nn.CrossEntropyLoss,而无需修改代码的其他部分。
为了应用损失并行,模型预测(通常形状为 [batch size, sequence length, vocabulary size])应该在词汇维度上进行分片。这可以通过标记最后一个线性投影层的输出布局来轻松完成。
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
# use DTensor as the output
use_local_output=False,
),
},
)
在上面的代码中,我们还在输出之前的归一化层上应用了序列并行。我们应用 use_local_output=False 让输出保持为 DTensor,以与 loss_parallel 上下文管理器配合使用。之后,可以简单地调用 cross_entropy 损失函数,如下所示。注意,反向计算也需要在上下文中进行。
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel
pred = model(input_ids)
with loss_parallel():
# assuming pred and labels are of the shape [batch, seq, vocab]
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
loss.backward()
将张量并行与完全分片数据并行结合使用#
现在我们已经展示了如何将张量/序列并行应用于模型,让我们看看张量并行和完全分片数据并行如何协同工作。由于张量并行会产生阻塞计算的通信,我们希望确保它在快速通信通道(如 NVLink)内运行。在实践中,我们通常在每个主机内应用张量并行,并在主机之间应用完全分片数据并行。
图 3. FSDP 和 TP 工作在不同的设备维度上,FSDP 通信发生在主机间,TP 通信发生在主机内。#
这种二维并行模式可以通过二维 DeviceMesh 轻松表达,我们只需要将每个“子”DeviceMesh 传递给每个单独的并行 API。
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.fsdp import fully_shard
# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices
model = Model(...)
tp_plan = {...}
# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)
这将允许我们轻松地在每个主机内(主机内)应用张量并行,并在主机之间(主机间)应用 FSDP,且对 Llama 模型进行 **0 代码更改**。张量(模型)并行和数据并行技术的结合提供了在大量 GPU 上继续增加模型规模和高效训练的能力。
结论#
本教程演示了如何使用张量并行与完全分片数据并行相结合,在数百到数千个 GPU 上训练类似 Transformer 的大型模型。它解释了如何将张量并行应用于模型的不同部分,并且对模型本身 **无需代码更改**。张量并行是用于大规模训练的高效模型并行技术。
要查看本教程中解释的完整端到端代码示例,请参阅 pytorch/examples 仓库中的 张量并行示例。