评价此页

分布式流水线并行介绍#

创建日期:2024 年 7 月 9 日 | 最后更新:2025 年 11 月 5 日 | 最后验证:2024 年 11 月 5 日

作者: Howard Huang

注意

editgithub 上查看并编辑此教程。

本教程使用 GPT 风格的 Transformer 模型来演示如何通过 torch.distributed.pipelining API 实现分布式流水线并行。

您将学到什么
  • 如何使用 torch.distributed.pipelining API

  • 如何将流水线并行应用于 Transformer 模型

  • 如何在微批次(microbatches)集合上利用不同的调度策略

先决条件

设置#

通过 torch.distributed.pipelining,我们将对模型的执行进行分区,并在微批次上调度计算。我们将使用一个简化的 Transformer 解码器模型。该模型架构仅用于教学目的,包含多个 Transformer 解码器层,旨在演示如何将模型拆分为不同的分块。首先,让我们定义模型。

import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class ModelArgs:
   dim: int = 512
   n_layers: int = 8
   n_heads: int = 8
   vocab_size: int = 10000

class Transformer(nn.Module):
   def __init__(self, model_args: ModelArgs):
      super().__init__()

      self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

      # Using a ModuleDict lets us delete layers witout affecting names,
      # ensuring checkpoints will correctly save and load.
      self.layers = torch.nn.ModuleDict()
      for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads)

      self.norm = nn.LayerNorm(model_args.dim)
      self.output = nn.Linear(model_args.dim, model_args.vocab_size)

   def forward(self, tokens: torch.Tensor):
      # Handling layers being 'None' at runtime enables easy pipeline splitting
      h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

      for layer in self.layers.values():
            h = layer(h, h)

      h = self.norm(h) if self.norm else h
      output = self.output(h).clone() if self.output else h
      return output

然后,我们需要在脚本中导入必要的库并初始化分布式训练过程。在此案例中,我们定义了一些后续脚本中会用到的全局变量。

import os
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe

global rank, device, pp_group, stage_index, num_stages
def init_distributed():
   global rank, device, pp_group, stage_index, num_stages
   rank = int(os.environ["LOCAL_RANK"])
   world_size = int(os.environ["WORLD_SIZE"])
   device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
   dist.init_process_group()

   # This group can be a sub-group in the N-D parallel case
   pp_group = dist.new_group()
   stage_index = rank
   num_stages = world_size

rankworld_sizeinit_process_group() 的代码对你来说应该很熟悉,因为它们在所有分布式程序中都很常用。流水线并行特有的全局变量包括 pp_group(将用于发送/接收通信的进程组)、stage_index(在本例中,每个阶段对应一个 rank,因此索引等同于 rank)以及 num_stages(等同于 world_size)。

num_stages 用于设置流水线并行调度中使用的阶段数量。例如,当 num_stages=4 时,一个微批次在完成计算前需要经过 4 次前向传播和 4 次反向传播。stage_index 对于框架了解阶段间的通信方式至关重要。例如,对于第一阶段(stage_index=0),它将直接使用来自数据加载器(dataloader)的数据,无需从任何之前的节点接收数据即可执行计算。

第一步:对 Transformer 模型进行分区#

有两种不同的模型分区方式。

第一种是手动模式,即我们可以通过删除模型属性的特定部分来手动创建两个模型实例。在此双阶段(2 个 rank)的示例中,模型被一分为二。

def manual_model_split(model) -> PipelineStage:
   if stage_index == 0:
      # prepare the first stage model
      for i in range(4, 8):
            del model.layers[str(i)]
      model.norm = None
      model.output = None

   elif stage_index == 1:
      # prepare the second stage model
      for i in range(4):
            del model.layers[str(i)]
      model.tok_embeddings = None

   stage = PipelineStage(
      model,
      stage_index,
      num_stages,
      device,
   )
   return stage

我们可以看到,第一阶段没有层归一化(layer norm)或输出层,仅包含前四个 Transformer 块。第二阶段没有输入嵌入层(input embedding layers),但包含输出层和最后的四个 Transformer 块。函数随后会返回当前 rank 的 PipelineStage

第二种方法是基于追踪器(tracer-based)的模式,它根据 split_spec 参数自动拆分模型。使用流水线规范,我们可以指示 torch.distributed.pipelining 在何处拆分模型。在以下代码块中,我们在第 4 个 Transformer 解码器层之前进行拆分,这与上述手动拆分方式相呼应。同样地,在拆分完成后,我们可以通过调用 build_stage 来获取 PipelineStage

def tracer_model_split(model, example_input_microbatch) -> PipelineStage:
   pipe = pipeline(
      module=model,
      mb_args=(example_input_microbatch,),
      split_spec={
         "layers.4": SplitPoint.BEGINNING,
      }
   )
   stage = pipe.build_stage(stage_index, device, pp_group)
   return stage

第二步:定义主要执行逻辑#

在主函数中,我们将创建一个特定的流水线调度策略供各个阶段遵循。torch.distributed.pipelining 支持多种调度策略,包括单阶段-单 rank 调度(如 GPipe1F1B),以及多阶段-单 rank 调度(如 Interleaved1F1BLoopedBFS)。

if __name__ == "__main__":
   init_distributed()
   num_microbatches = 4
   model_args = ModelArgs()
   model = Transformer(model_args)

   # Dummy data
   x = torch.ones(32, 500, dtype=torch.long)
   y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long)
   example_input_microbatch = x.chunk(num_microbatches)[0]

   # Option 1: Manual model splitting
   stage = manual_model_split(model)

   # Option 2: Tracer model splitting
   # stage = tracer_model_split(model, example_input_microbatch)

   model.to(device)
   x = x.to(device)
   y = y.to(device)

   def tokenwise_loss_fn(outputs, targets):
      loss_fn = nn.CrossEntropyLoss()
      outputs = outputs.reshape(-1, model_args.vocab_size)
      targets = targets.reshape(-1)
      return loss_fn(outputs, targets)

   schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)

   if rank == 0:
      schedule.step(x)
   elif rank == 1:
      losses = []
      output = schedule.step(target=y, losses=losses)
      print(f"losses: {losses}")
   dist.destroy_process_group()

在上面的示例中,我们使用了手动方法拆分模型,但可以将代码取消注释以尝试基于追踪器的模型拆分函数。在我们的调度中,需要传入微批次的数量以及用于评估目标的损失函数。

.step() 函数会处理整个小批次(minibatch),并根据先前传入的 n_microbatches 自动将其拆分为微批次。随后,微批次将按照调度类进行操作。在上述示例中,我们使用的是 GPipe,它遵循简单的“先全部前向、后全部反向”的调度顺序。从 rank 1 返回的输出将与模型在单个 GPU 上运行整个批次时的结果相同。同样,我们可以传入一个 losses 容器来存储每个微批次对应的损失值。

第三步:启动分布式进程#

最后,我们准备运行脚本。我们将使用 torchrun 创建一个单主机、2 进程的作业。我们的脚本编写方式已确保 rank 0 执行流水线阶段 0 所需的逻辑,而 rank 1 执行流水线阶段 1 的逻辑。

torchrun --nnodes 1 --nproc_per_node 2 pipelining_tutorial.py

结论#

在本教程中,我们学习了如何使用 PyTorch 的 torch.distributed.pipelining API 实现分布式流水线并行。我们探讨了环境设置、定义 Transformer 模型以及为分布式训练进行分区的方法。我们讨论了两种模型分区方法(手动和基于追踪器),并演示了如何跨不同阶段在微批次上调度计算。最后,我们介绍了流水线调度的执行,以及如何使用 torchrun 启动分布式进程。

其他资源#

我们已成功将 torch.distributed.pipelining 集成到 TorchTitan 存储库 中。TorchTitan 是一个用于使用原生 PyTorch 进行大规模 LLM 训练的简洁、极简的代码库。若要了解生产环境下的流水线并行使用方式以及与其他分布式技术的组合,请参阅 TorchTitan 的 3D 并行端到端示例