评价此页

分布式流水线并行简介#

创建日期:2024 年 7 月 9 日 | 最后更新:2024 年 12 月 12 日 | 最后验证:2024 年 11 月 5 日

作者Howard Huang

注意

editgithub 上查看和编辑此教程。

本教程使用一个类 GPT 的 transformer 模型来演示如何使用 torch.distributed.pipelining API 实现分布式流水线并行。

您将学到什么
  • 如何使用 torch.distributed.pipelining API

  • 如何将流水线并行应用于 transformer 模型

  • 如何对一组微批次利用不同的调度策略

先决条件

设置#

通过 torch.distributed.pipelining,我们将对模型的执行进行分区,并在微批次上调度计算。我们将使用一个简化的 transformer 解码器模型。该模型架构仅用于教学目的,包含多个 transformer 解码器层,以便我们演示如何将模型分割成不同的块。首先,让我们定义模型。

import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class ModelArgs:
   dim: int = 512
   n_layers: int = 8
   n_heads: int = 8
   vocab_size: int = 10000

class Transformer(nn.Module):
   def __init__(self, model_args: ModelArgs):
      super().__init__()

      self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

      # Using a ModuleDict lets us delete layers witout affecting names,
      # ensuring checkpoints will correctly save and load.
      self.layers = torch.nn.ModuleDict()
      for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads)

      self.norm = nn.LayerNorm(model_args.dim)
      self.output = nn.Linear(model_args.dim, model_args.vocab_size)

   def forward(self, tokens: torch.Tensor):
      # Handling layers being 'None' at runtime enables easy pipeline splitting
      h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

      for layer in self.layers.values():
            h = layer(h, h)

      h = self.norm(h) if self.norm else h
      output = self.output(h).clone() if self.output else h
      return output

然后,我们需要在脚本中导入必要的库并初始化分布式训练过程。在这种情况下,我们定义了一些全局变量以供脚本稍后使用。

import os
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe

global rank, device, pp_group, stage_index, num_stages
def init_distributed():
   global rank, device, pp_group, stage_index, num_stages
   rank = int(os.environ["LOCAL_RANK"])
   world_size = int(os.environ["WORLD_SIZE"])
   device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
   dist.init_process_group()

   # This group can be a sub-group in the N-D parallel case
   pp_group = dist.new_group()
   stage_index = rank
   num_stages = world_size

通常在所有分布式程序中都会使用 rankworld_sizeinit_process_group() 代码。特定于流水线并行的全局变量包括 pp_group,它是用于 send/recv 通信的进程组;stage_index,在此示例中,每个 stage 只有一个 rank,因此索引等同于 rank;以及 num_stages,它等同于 world_size。

num_stages 用于设置流水线并行调度中使用的 stage 数量。例如,对于 num_stages=4,一个微批次需要经过 4 次前向传播和 4 次后向传播才能完成。框架需要 stage_index 来了解如何在 stage 之间进行通信。例如,对于第一个 stage(stage_index=0),它将使用来自数据加载器的数据,而无需从任何前驱 peer 接收数据即可执行其计算。

步骤 1:分割 Transformer 模型#

有两种不同的模型分割方法

第一种是手动模式,我们可以通过删除模型的部分属性来手动创建模型的两个实例。在此示例中,对于两个 stage(2 个 rank),模型被分成两半。

def manual_model_split(model) -> PipelineStage:
   if stage_index == 0:
      # prepare the first stage model
      for i in range(4, 8):
            del model.layers[str(i)]
      model.norm = None
      model.output = None

   elif stage_index == 1:
      # prepare the second stage model
      for i in range(4):
            del model.layers[str(i)]
      model.tok_embeddings = None

   stage = PipelineStage(
      model,
      stage_index,
      num_stages,
      device,
   )
   return stage

我们可以看到第一个 stage 没有 layer norm 或输出层,只包含前四个 transformer 块。第二个 stage 没有输入嵌入层,但包含输出层和最后的四个 transformer 块。然后该函数返回当前 rank 的 PipelineStage

第二种方法是基于跟踪器的模式,它根据 split_spec 参数自动分割模型。使用流水线规范,我们可以指示 torch.distributed.pipelining 在哪里分割模型。在下面的代码块中,我们在第 4 个 transformer 解码器层之前进行分割,这与上面描述的手动分割相呼应。同样,在完成分割后,可以通过调用 build_stage 来检索 PipelineStage

步骤 2:定义主执行#

在主函数中,我们将创建一个特定的流水线调度,stage 应遵循该调度。torch.distributed.pipelining 支持多种调度策略,包括单 stage-per-rank 调度 GPipe1F1B,以及多 stage-per-rank 调度,如 Interleaved1F1BLoopedBFS

if __name__ == "__main__":
   init_distributed()
   num_microbatches = 4
   model_args = ModelArgs()
   model = Transformer(model_args)

   # Dummy data
   x = torch.ones(32, 500, dtype=torch.long)
   y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long)
   example_input_microbatch = x.chunk(num_microbatches)[0]

   # Option 1: Manual model splitting
   stage = manual_model_split(model)

   # Option 2: Tracer model splitting
   # stage = tracer_model_split(model, example_input_microbatch)

   model.to(device)
   x = x.to(device)
   y = y.to(device)

   def tokenwise_loss_fn(outputs, targets):
      loss_fn = nn.CrossEntropyLoss()
      outputs = outputs.reshape(-1, model_args.vocab_size)
      targets = targets.reshape(-1)
      return loss_fn(outputs, targets)

   schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)

   if rank == 0:
      schedule.step(x)
   elif rank == 1:
      losses = []
      output = schedule.step(target=y, losses=losses)
      print(f"losses: {losses}")
   dist.destroy_process_group()

在上面的示例中,我们使用了手动方法来分割模型,但可以取消注释该代码以尝试基于跟踪器的模型分割函数。在我们的调度中,我们需要传入微批次的数量以及用于评估目标的损失函数。

.step() 函数处理整个小批次,并根据之前传入的 n_microbatches 自动将其分割成微批次。然后根据调度类对微批次进行操作。在上面的示例中,我们使用的是 GPipe,它遵循简单的全前向传播然后全后向传播的调度。从 rank 1 返回的输出将与模型在单个 GPU 上运行整个批次时相同。类似地,我们可以传入一个 losses 容器来存储每个微批次对应的损失。

步骤 3:启动分布式进程#

最后,我们准备运行脚本。我们将使用 torchrun 来创建一个单主机、2 进程的任务。我们的脚本已经以 rank 0 执行流水线 stage 0 所需的逻辑,rank 1 执行流水线 stage 1 所需的逻辑的方式编写。

torchrun --nnodes 1 --nproc_per_node 2 pipelining_tutorial.py

结论#

在本教程中,我们学习了如何使用 PyTorch 的 torch.distributed.pipelining API 来实现分布式流水线并行。我们探索了环境的设置、 transformer 模型的定义以及为分布式训练对其进行的分割。我们讨论了两种模型分割方法:手动和基于跟踪器,并演示了如何在不同 stage 上调度微批次的计算。最后,我们介绍了流水线调度的执行以及使用 torchrun 启动分布式进程。

附加资源#

我们已成功将 torch.distributed.pipelining 集成到 torchtitan 存储库中。TorchTitan 是一个干净、最小的代码库,用于使用原生 PyTorch 进行大规模 LLM 训练。有关生产就绪的流水线并行用法以及与其他分布式技术的组合,请参阅 TorchTitan 的 3D 并行端到端示例