评价此页

使用张量并行(TP)进行大规模 Transformer 模型训练#

创建日期:2024 年 4 月 19 日 | 最后更新:2025 年 7 月 18 日 | 最后验证:2024 年 11 月 5 日

作者Wanchao Liang, Tianyu Liu

注意

editgithub 上查看和编辑此教程。

本教程演示了如何使用张量并行(TP)和完全分片数据并行(FSDP)在数百到数千个 GPU 上训练大型 Transformer 类模型。

先决条件

张量并行如何工作?#

张量并行(TP)最初在 Megatron-LM 论文中提出,是一种用于训练大规模 Transformer 模型的高效模型并行技术。序列并行(SP)是我们在此教程中提到的张量并行的变体,它在序列维度上对 `nn.LayerNorm` 或 `nn.RMSNorm` 进行分片,以进一步节省训练期间的激活内存。随着模型变得越来越大,激活内存成为瓶颈,因此在张量并行训练中,通常对 `LayerNorm` 或 `RMSNorm` 层应用序列并行。

Megatron-LM TP

图 1. 在 Transformer 模型的 MLP 和自注意力层上表示张量并行风格的分片,其中注意力/MLP 中的矩阵乘法通过分片计算完成(图片来源#

总的来说,PyTorch 张量并行的工作方式如下:

分片初始化

  • 确定将哪个 `ParallelStyle` 应用于每个层,并通过调用 `parallelize_module` 来分片初始化后的模块。

  • 并行化后的模块的模型参数将被切换为 DTensors,DTensor 将负责使用分片计算来运行并行化后的模块。

运行时前向/后向传播

  • 根据用户为每个 `ParallelStyle` 指定的输入/输出 DTensor 布局,它将运行适当的通信操作来转换输入/输出的 DTensor 布局(例如 `allreduce`、`allgather` 和 `reduce_scatter`)。

  • 运行并行化层的分片计算以节省计算/内存(例如 `nn.Linear`、`nn.Embedding`)。

何时以及为何应该应用张量并行#

PyTorch 完全分片数据并行(FSDP)已经具备了将模型训练扩展到特定数量 GPU 的能力。然而,当在模型大小和 GPU 数量方面进一步扩展模型训练时,许多额外的挑战会出现,这可能需要将张量并行与 FSDP 相结合。

  1. 当世界规模(GPU 数量)变得过大(超过 128/256 个 GPU)时,FSDP 的集合通信(如 `allgather`)会被环延迟所主导。通过在 FSDP 之上实现 TP/SP,可以将 FSDP 世界规模减少 8 倍,通过仅对主机内部应用 FSDP,从而将延迟成本降低相同的数量。

  2. 在数据并行达到极限,由于收敛性和 GPU 内存限制而无法将全局批次大小提高到超过 GPU 数量时,张量/序列并行是唯一已知的方法可以“大致”确定全局批次大小并继续使用更多 GPU 进行扩展。这意味着模型大小和 GPU 数量都可以继续扩展。

  3. 对于某些类型的模型,当局部批次大小变小时,TP/SP 可以产生更优化的浮点运算(FLOPS)矩阵乘法形状。

那么,在预训练时,达到这些限制有多容易?目前,预训练一个具有数十亿或数万亿个 token 的大型语言模型(LLM)可能需要数月时间,即使使用数千个 GPU。

  • 在 LLM 大规模训练时,总会遇到限制 1。例如,Llama 2 70B 在 2k 个 GPU 上训练了 35 天,在 2k 的规模下需要多维并行。

  • 当 Transformer 模型变得更大(例如 Llama2 70B)时,也会很快达到限制 2。即使局部 `batch_size=1`,由于内存和收敛性限制,也无法仅使用 FSDP。例如,Llama 2 的全局批次大小为 1K,因此在 2K 个 GPU 上无法仅使用数据并行。

如何应用张量并行#

PyTorch 张量并行 API 提供了一组模块级原语(`ParallelStyle`),用于配置模型每个单独层的分片,包括:

  • `ColwiseParallel` 和 `RowwiseParallel`:按列或按行分片 `nn.Linear` 和 `nn.Embedding`。

  • `SequenceParallel`:对 `nn.LayerNorm`、`nn.Dropout`、`RMSNormPython` 等执行分片计算。

  • `PrepareModuleInput` 和 `PrepareModuleOutput`:使用适当的通信操作配置模块输入/输出的分片布局。

为了演示如何使用 PyTorch 原生的张量并行 API,让我们看一下一个常见的 Transformer 模型。在本教程中,我们使用最新的 Llama2 模型作为参考 Transformer 模型实现,因为它也在社区中被广泛使用。

由于张量并行会在一组设备上分片单个张量,因此我们需要先设置分布式环境(例如 NCCL communicators)。张量并行是一种单程序多数据(SPMD)分片算法,类似于 PyTorch DDP/FSDP,它在底层利用 PyTorch DTensor 来执行分片。它还利用 DeviceMesh 抽象(底层管理 ProcessGroups)来进行设备管理和分片。有关如何使用 DeviceMesh 设置多维并行,请参阅本教程。张量并行通常在每个主机内部工作,所以我们先初始化一个连接主机内 8 个 GPU 的 DeviceMesh。

from torch.distributed.device_mesh import init_device_mesh

tp_mesh = init_device_mesh("cuda", (8,))

现在我们已经初始化了 DeviceMesh,让我们详细看看 Llama 2 模型架构,并了解如何执行张量并行分片。这里我们重点关注核心的 `TransformerBlock`,其中 Transformer 模型堆叠相同的 `TransformerBlock` 以扩展模型。

核心的 `TransformerBlock` 由一个 `Attention` 层和一个 `FeedForward` 层组成。让我们先看看更简单的 `FeedForward` 层。对于 `FeedForward` 层,它包含三个 Linear 层,执行 SwiGLU 风格的 MLP,查看其 forward 函数:

# forward in the FeedForward layer
def forward(self, x):
    return self.w2(F.silu(self.w1(x)) * self.w3(x))

它并行执行 `w1` 和 `w3` 的矩阵乘法,然后执行 `w2` 的矩阵乘法,结果是 `w1`/`w3` 线性投影结果的组合。这意味着我们可以借鉴张量并行论文中的思想,将 `w1`/`w3` Linear 层按列分片,并将 `w2` Linear层按行分片,这样在所有三个层结束时只有一个 `allreduce` 通信。使用 PyTorch 原生的张量并行,我们可以为 `FeedForward` 层简单地创建一个 `parallelize_plan`,如下所示:

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "feed_foward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

这就是我们使用 PyTorch 张量并行 API 为 `FeedForward` 层配置分片的方式。请注意,用户只需指定如何分片各个层,通信(例如 `allreduce`)将在后台自动进行。

接下来是 `Attention` 层。它由 `wq`、`wk`、`wv` Linear 层组成,用于将输入投影到 `q` / `k` / `v`,然后它执行注意力计算并与 `wo` Linear 层进行输出投影。张量并行在这里旨在对 q/k/v 投影执行按列分片,并对 `wo` Linear 投影执行按行分片。因此,我们可以将 Attention 的计划添加到我们刚刚起草的 `tp_plan` 中:

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

这几乎是我们用于对 `TransformerBlock` 应用张量并行所需的 `layer_tp_plan`。但是,我们需要注意的一点是,当按列分片线性层时,线性层的输出将在最后一个张量维度上分片,而按行分片线性层直接接受一个在最后一个维度上分片的输入。如果在按列分片的线性层和按行分片的线性层之间有任何其他张量操作(例如 view 操作),我们需要调整相关的形状相关操作以适应分片形状。

对于 Llama 模型,在注意力层中,有几个与形状相关的 view 操作。具体来说,对于 `wq` / `wk` / `wv` 线性层的按列并行,激活张量在 `num_heads` 维度上进行分片。为了管理全局 `num_heads` 和局部 `num_heads` 之间的差异,我们应该设置 `use_local_output=False` 以确保输出是 DTensor。与常规张量不同,DTensor 了解并行计划,并将自动处理 `num_heads` 维度的变化。

最后,我们需要调用 `parallelize_module` API 来使每个 `TransformerBlock` 的计划生效。在底层,它将 `Attention` 和 `FeedForward` 层中的模型参数分布到 DTensors,并(如果需要)为模块的输入和输出(分别在每个模块之前和之后)注册通信钩子。

for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {...}  # i.e. the plan we just generated

    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )

现在我们已经详细说明了每个 `TransformerBlock` 的分片计划,通常在第一层有一个 `nn.Embedding` 层,在最后一层有一个 `nn.Linear` 投影层。用户可以选择按行或按列分片第一个 `nn.Embedding` 层,并按列分片最后一个 `nn.Linear` 投影层,并指定适当的输入和输出布局。下面是一个示例:

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
        ),
        "output": ColwiseParallel(
            output_layouts=Replicate(),
        ),
    }
)

注意

如果需要分区的模型太大而无法放入 CPU 内存,可以采用 `meta` 设备初始化(例如,先在 meta 设备上初始化模型,分片层,然后具体化模型),或者在 Transformer 模型初始化过程中逐层并行化 `TransformerBlock`。

将序列并行应用于 `LayerNorm/RMSNorm` 层#

序列并行建立在上面介绍的张量并行之上。与仅在 `Attention` 模块和 `FeedForward` 模块内对张量进行分片并保持其模块输入和输出(即前向传播中的激活和后向传播中的梯度)副本的常规张量并行相比,序列并行在序列维度上保持它们的分片状态。

在一个典型的 `TransformerBlock` 中,forward 函数结合了 norm 层(`LayerNorm` 或 `RMSNorm`)、一个注意力层、一个前馈层以及残差连接。例如:

# forward in a TransformerBlock
def forward(self, x):
    h = x + self.attention(self.attention_norm(x))
    out = h + self.feed_forward(self.ffn_norm(h))
    return out

在大多数用例中,`Attention` 和 `FeedForward` 模块之外的激活(和梯度)的形状为 `[batch size, sequence length, hidden dimension]`。用 DTensor 的术语来说,序列并行在模块的前向/后向计算中使用 `Shard(1)` 布局。遵循之前的代码示例,下面的代码演示了如何将序列并行应用于 `TransformerBlock` 中的 norm 层:

首先,让我们导入序列并行所需的依赖项:

from torch.distributed.tensor.parallel import (
    PrepareModuleInput,
    SequenceParallel,
)

接下来,让我们调整 `layer_tp_plan` 以在 `RMSNorm` 层上启用序列并行:

layer_tp_plan = {
    # Now the input and output of SequenceParallel has Shard(1) layouts,
    # to represent the input/output tensors sharded on the sequence dimension
    "attention_norm": SequenceParallel(),
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1), Replicate()),
        desired_input_layouts=(Replicate(), Replicate()),
    ),
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

可以看到,我们现在使用 `PrepareModuleInput` 来修改 Attention 和 FeedForward 层的模块输入布局,从 `Shard(1)` 更改为 `Replicate()`,并将其输出布局标记为 `Shard(1)`。与张量并行发生的情况类似,用户只需指定输入和输出的张量分片布局,层之间的通信将自动发生。

请注意,使用序列并行时,我们假设 `TransformerBlock` 的输入和输出始终在序列维度上分片,以便可以无缝地连接多个 `TransformerBlock`。这可以通过显式地将第一个 `nn.Embedding` 层的输出和最后一个 `nn.Linear` 投影层的输入指定为 `Shard(1)` 来实现。

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate()
        ),
    }
)

应用损失并行#

损失并行是一项相关的技术,用于在计算损失函数时节省内存和通信,因为模型输出通常非常大。在损失并行中,当模型输出在(通常非常大的)词汇表维度上分片时,交叉熵损失可以高效地计算,而无需将所有模型输出收集到每个 GPU 上。这不仅显着减少了内存消耗,而且通过减少通信开销并并行执行分片计算来提高了训练速度。下图简要说明了损失并行如何通过执行分片计算来避免将所有模型输出收集到每个 GPU 上。

loss parallel

图 2. 在一个 GPU 上使用损失并行计算的交叉熵损失前向传播。蓝色表示分片张量;绿色表示复制张量;黄色表示具有部分值的张量(待 all-reduce)。黑色箭头表示本地计算;红色箭头表示 GPU 之间的函数式集合通信。#

在 PyTorch 张量并行 API 中,可以通过上下文管理器 `loss_parallel` 来启用损失并行,使用它可以直接使用 `torch.nn.functional.cross_entropy` 或 `torch.nn.CrossEntropyLoss`,而无需修改代码的其他部分。

为了应用损失并行,模型预测(通常形状为 `[batch size, sequence length, vocabulary size]`)应该在词汇表维度上进行分片。这可以通过标记最后一个线性投影层输出的输出布局轻松实现:

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            # use DTensor as the output
            use_local_output=False,
        ),
    },
)

在上面的代码中,我们还对输出之前的 norm 层应用了序列并行。我们应用 `use_local_output=False` 以使输出保持为 DTensor,以便与 `loss_parallel` 上下文管理器配合使用。之后,可以像下面那样简单地调用 cross_entropy 损失函数。请注意,后向传播也需要在此上下文内进行。

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

pred = model(input_ids)
with loss_parallel():
    # assuming pred and labels are of the shape [batch, seq, vocab]
    loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
    loss.backward()

将张量并行与完全分片数据并行结合#

既然我们已经展示了如何将张量/序列并行应用于模型,我们还可以看看张量并行和完全分片数据并行如何协同工作。由于张量并行会产生阻塞计算的通信,因此我们希望确保它在快速通信通道(如 NVLink)内运行。实际上,我们通常在每个主机内部应用张量并行,并在主机之间应用完全分片数据并行。

fsdp + tp

图 3. FSDP 和 TP 在独立的设备维度上工作,FSDP 通信发生在主机之间,TP 通信发生在主机内部。#

这种二维并行模式可以通过二维 DeviceMesh 轻松表达,我们只需要将每个“子” DeviceMesh 传递给每个单独的并行 API:

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.fsdp import fully_shard

# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices

model = Model(...)

tp_plan = {...}

# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)

这将使我们能够在每个主机内部(主机内)轻松应用张量并行,并在主机之间(主机间)应用 FSDP,并且 Llama 模型 **无需更改任何代码**。张量(模型)并行和数据并行技术相结合,能够通过大量 GPU 继续增加模型大小并高效地进行训练。

结论#

本教程演示了如何将张量并行与完全分片数据并行结合,在数百到数千个 GPU 上训练大型 Transformer 类模型。它解释了如何将张量并行应用于模型的不同部分,而 **无需修改模型本身的代码**。张量并行是一种用于大规模训练的高效模型并行技术。

要查看本教程中解释的完整的端到端代码示例,请参阅 pytorch/examples 存储库中的 张量并行示例