评价此页

管道并行#

创建于:2025 年 6 月 16 日 | 最后更新于:2025 年 6 月 16 日

注意

torch.distributed.pipelining 目前处于 Alpha 阶段,正在开发中。API 可能会有变动。它从 PiPPy 项目迁移而来。

为什么选择管道并行?#

管道并行是深度学习的**原始**并行方式之一。它允许将模型**执行**进行分区,以便多个**微批次**可以同时执行模型的不同部分。管道并行是一种有效的技术,适用于

  • 大规模训练

  • 带宽受限的集群

  • 大型模型推理

上述场景的共同点是,每个设备的计算无法掩盖传统并行(例如 FSDP 的权重全收集)的通信开销。

torch.distributed.pipelining 是什么?#

虽然管道化对于扩展很有前景,但通常难以实现,因为它除了模型权重之外还需要**划分模型的执行**。执行划分通常需要对模型进行侵入式代码更改。复杂性的另一个方面来自于**在分布式环境中调度微批次**,同时考虑到**数据流依赖**。

pipelining 包提供了一个工具包,可以**自动**完成上述操作,从而在**通用**模型上轻松实现管道并行。

它由两部分组成:一个**分割前端**和一个**分布式运行时**。分割前端接收您的模型代码,将其分割成“模型分区”,并捕获数据流关系。分布式运行时并行地在不同设备上执行管道阶段,处理微批次分割、调度、通信和梯度传播等。

总的来说,pipelining 包提供了以下功能:

  • 根据简单规范分割模型代码。

  • 丰富地支持管道调度,包括 GPipe、1F1B、交错式 1F1B 和循环 BFS,并提供编写自定义调度的基础设施。

  • 对跨主机管道并行的第一类支持,因为这通常是 PP 使用的地方(通过较慢的互连)。

  • 与其他PyTorch并行技术(如数据并行(DDP、FSDP)或张量并行)的可组合性。TorchTitan项目演示了Llama模型上的“3D并行”应用。

步骤 1:构建 PipelineStage#

在使用 PipelineSchedule 之前,我们需要创建 PipelineStage 对象来包装在该阶段运行的模型部分。PipelineStage 负责分配通信缓冲区并创建发送/接收操作以与其对等方通信。它管理中间缓冲区,例如尚未消耗的前向输出,并提供了一个用于运行阶段模型反向的实用程序。

PipelineStage 需要知道阶段模型的输入和输出形状,以便正确分配通信缓冲区。形状必须是静态的,例如,在运行时形状不能逐步改变。如果运行时形状与预期形状不匹配,将引发 PipeliningShapeError 类错误。在与其他并行技术组合或应用混合精度时,必须考虑这些技术,以便 PipelineStage 知道阶段模块在运行时的正确形状(和数据类型)。

用户可以直接构造 PipelineStage 实例,通过传入一个表示模型应在该阶段运行部分的 nn.Module。这可能需要更改原始模型代码。请参阅 选项 1:手动拆分模型 中的示例。

或者,拆分前端可以使用图分区自动将您的模型拆分为一系列 nn.Module。此技术要求模型可使用 torch.Export 追踪。结果 nn.Module 与其他并行技术的组合性仍处于实验阶段,可能需要一些变通方法。如果用户不能轻松更改模型代码,使用此前端可能更具吸引力。有关更多信息,请参阅 选项 2:自动拆分模型

步骤 2:使用 PipelineSchedule 执行#

我们现在可以将 PipelineStage 附加到流水线调度,并使用输入数据运行调度。以下是一个GPipe示例

from torch.distributed.pipelining import ScheduleGPipe

# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)

# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)

# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically
if rank == 0:
    schedule.step(x)
else:
    output = schedule.step()

请注意,上述代码需要为每个 worker 启动,因此我们使用启动服务来启动多个进程。

torchrun --nproc_per_node=2 example.py

模型拆分选项#

选项 1:手动拆分模型#

要直接构造 PipelineStage,用户需要负责提供一个 nn.Module 实例,该实例拥有相关的 nn.Parametersnn.Buffers,并定义一个执行该阶段相关操作的 forward() 方法。例如,Torchtitan 中定义的 Transformer 类的精简版本展示了一种构建易于分区模型的模式。

class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()

        self.tok_embeddings = nn.Embedding(...)

        # Using a ModuleDict lets us delete layers without affecting names,
        # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(...)

        self.output = nn.Linear(...)

    def forward(self, tokens: torch.Tensor):
        # Handling layers being 'None' at runtime enables easy pipeline splitting
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        output = self.output(h).float() if self.output else h
        return output

以这种方式定义的模型可以通过以下方式轻松配置每个阶段:首先初始化整个模型(使用元设备以避免 OOM 错误),删除该阶段不需要的层,然后创建一个包装模型的 PipelineStage。例如

with torch.device("meta"):
    assert num_stages == 2, "This is a simple 2-stage example"

    # we construct the entire model, then delete the parts we do not need for this stage
    # in practice, this can be done using a helper function that automatically divides up layers across stages.
    model = Transformer()

    if stage_index == 0:
        # prepare the first stage model
        del model.layers["1"]
        model.norm = None
        model.output = None

    elif stage_index == 1:
        # prepare the second stage model
        model.tok_embeddings = None
        del model.layers["0"]

    from torch.distributed.pipelining import PipelineStage
    stage = PipelineStage(
        model,
        stage_index,
        num_stages,
        device,
    )

当与其他数据或模型并行技术组合时,如果模型块的输出形状/数据类型会受到影响,可能还需要 output_args

选项 2:自动拆分模型#

如果您有一个完整的模型,并且不想花时间将其修改为一系列“模型分区”,pipeline API 可以帮助您。以下是一个简短示例

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(10, 3)
        self.layers = torch.nn.ModuleList(
            Layer() for _ in range(2)
        )
        self.lm = LMHead()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.emb(x)
        for layer in self.layers:
            x = layer(x)
        x = self.lm(x)
        return x

如果我们打印模型,可以看到多个层次结构,这使得手动拆分变得困难

Model(
  (emb): Embedding(10, 3)
  (layers): ModuleList(
    (0-1): 2 x Layer(
      (lin): Linear(in_features=3, out_features=3, bias=True)
    )
  )
  (lm): LMHead(
    (proj): Linear(in_features=3, out_features=3, bias=True)
  )
)

让我们看看 pipeline API 如何工作

from torch.distributed.pipelining import pipeline, SplitPoint

# An example micro-batch input
x = torch.LongTensor([1, 2, 4, 5])

pipe = pipeline(
    module=mod,
    mb_args=(x,),
    split_spec={
        "layers.1": SplitPoint.BEGINNING,
    }
)

pipeline API 根据 split_spec 拆分您的模型,其中 SplitPoint.BEGINNING 表示在 forward 函数中某个子模块执行*之前*添加拆分点,类似地,SplitPoint.END 表示在*之后*添加拆分点。

如果我们 print(pipe),我们可以看到

GraphModule(
  (submod_0): GraphModule(
    (emb): InterpreterModule()
    (layers): Module(
      (0): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
  )
  (submod_1): GraphModule(
    (layers): Module(
      (1): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
    (lm): InterpreterModule(
      (proj): InterpreterModule()
    )
  )
)

def forward(self, x):
    submod_0 = self.submod_0(x);  x = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    return (submod_1,)

“模型分区”由子模块(submod_0submod_1)表示,每个子模块都使用原始模型操作、权重和层次结构进行重建。此外,还重建了一个“根级”forward 函数,以捕获这些分区之间的数据流。这种数据流稍后将由流水线运行时以分布式方式重放。

Pipe 对象提供了一个检索“模型分区”的方法

stage_mod : nn.Module = pipe.get_stage_module(stage_idx)

返回的 stage_mod 是一个 nn.Module,您可以使用它创建优化器、保存或加载检查点,或应用其他并行化。

Pipe 还允许您在给定 ProcessGroup 的情况下在设备上创建分布式阶段运行时

stage = pipe.build_stage(stage_idx, device, group)

或者,如果您想在对 stage_mod 进行一些修改后构建阶段运行时,可以使用 build_stage API 的函数版本。例如

from torch.distributed.pipelining import build_stage
from torch.nn.parallel import DistributedDataParallel

dp_mod = DistributedDataParallel(stage_mod)
info = pipe.info()
stage = build_stage(dp_mod, stage_idx, info, device, group)

注意

pipeline 前端使用跟踪器(torch.export)将您的模型捕获到单个图中。如果您的模型无法进行完整图捕获,您可以使用我们下面的手动前端。

Hugging Face 示例#

在最初创建此包的 PiPPy 存储库中,我们保留了基于未修改的 Hugging Face 模型的示例。请参阅 examples/huggingface 目录。

示例包括

技术深入探究#

pipeline API 如何拆分模型?#

首先,pipeline API 通过跟踪模型将其转换为有向无环图(DAG)。它使用 torch.export(一个 PyTorch 2 完整图捕获工具)跟踪模型。

然后,它将一个阶段所需的**操作和参数**组合到一个重构的子模块中:submod_0submod_1 等。

与传统子模块访问方法(如 Module.children())不同,pipeline API 不仅切割模型的模块结构,还切割模型的 forward 函数。

这是必要的,因为像 Module.children() 这样的模型结构仅在 Module.__init__() 期间捕获信息,而不捕获任何关于 Module.forward() 的信息。换句话说,Module.children() 缺少对流水线化关键以下方面的信息

  • forward 中子模块的执行顺序

  • 子模块之间的激活流

  • 子模块之间是否存在任何函数式运算符(例如,reluadd 操作将不会被 Module.children() 捕获)。

相反,pipeline API 确保 forward 行为真正得以保留。它还捕获分区之间的激活流,帮助分布式运行时无需人工干预即可进行正确的发送/接收调用。

pipeline API 的另一个灵活性是拆分点可以在模型层次结构中的任意级别。在拆分分区中,与该分区相关的原始模型层次结构将无代价地重建。结果是,指向子模块或参数的完全限定名称(FQN)仍然有效,并且依赖 FQN 的服务(例如 FSDP、TP 或检查点)仍然可以以几乎零代码更改的方式与您分区的模块一起运行。

实现您自己的调度#

您可以通过扩展以下两个类之一来实现您自己的流水线调度

  • PipelineScheduleSingle

  • PipelineScheduleMulti

PipelineScheduleSingle 适用于每个 rank 仅分配*一个*阶段的调度。PipelineScheduleMulti 适用于每个 rank 分配多个阶段的调度。

例如,ScheduleGPipeSchedule1F1BPipelineScheduleSingle 的子类。而 ScheduleInterleaved1F1BScheduleLoopedBFSScheduleInterleavedZeroBubbleScheduleZBVZeroBubblePipelineScheduleMulti 的子类。

日志记录#

您可以使用 torch._logging 中的 TORCH_LOGS 环境变量开启附加日志记录

  • TORCH_LOGS=+pp 将显示 logging.DEBUG 消息及其以上所有级别。

  • TORCH_LOGS=pp 将显示 logging.INFO 消息及其以上级别。

  • TORCH_LOGS=-pp 将显示 logging.WARNING 消息及其以上级别。

API 参考#

模型拆分 API#

以下 API 集将您的模型转换为流水线表示。

class torch.distributed.pipelining.SplitPoint(value)[source]#

表示子模块执行中可能发生拆分点的枚举。:ivar BEGINNING: 表示在 forward 函数中某个子模块执行*之前*添加拆分点。:ivar END: 表示在 forward 函数中某个子模块执行*之后*添加拆分点。

torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)[source]#

根据规范拆分模块。

有关更多详细信息,请参阅 Pipe

参数
返回类型

Pipe 的流水线表示。

class torch.distributed.pipelining.Pipe(split_gm, num_stages, has_loss_and_backward, loss_spec)[source]#
torch.distributed.pipelining.pipe_split()[source]#

pipe_split 是一个特殊的运算符,用于标记模块中阶段之间的边界。它用于将模块拆分为多个阶段。如果您的带注释模块是立即运行的,则它是一个无操作。

示例

>>> def forward(self, x):
>>>     x = torch.mm(x, self.mm_param)
>>>     x = torch.relu(x)
>>>     pipe_split()
>>>     x = self.lin(x)
>>>     return x

上面的示例将拆分为两个阶段。

微批处理工具#

class torch.distributed.pipelining.microbatch.TensorChunkSpec(split_dim)[source]#

用于指定输入分块的类

torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks(args, kwargs, chunks, args_chunk_spec=None, kwargs_chunk_spec=None)[source]#

给定一系列 args 和 kwargs,根据它们各自的分块规范将其拆分为多个块。

参数
返回

分片 args 列表 kwargs_split: 分片 kwargs 列表

返回类型

args_split

torch.distributed.pipelining.microbatch.merge_chunks(chunks, chunk_spec)[source]#

给定一个块列表,根据块规范将它们合并为一个值。

参数
  • chunks (list[Any]) – 块列表

  • chunk_spec – 块的分块规范

返回

合并值

返回类型

value

流水线阶段#

class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args=None, output_args=None, group=None, dw_builder=None)[source]#

表示流水线并行设置中流水线阶段的类。

PipelineStage 假设模型是顺序分区的,即模型被分成若干块,其中一个块的输出作为下一个块的输入,没有跳过连接。

PipelineStage 通过将输出从阶段 0 传播到阶段 1,依此类推,以线性顺序自动执行运行时形状/数据类型推断。要绕过形状推断,请将 input_argsoutput_args 传递给每个 PipelineStage 实例。

参数
  • submodule (nn.Module) – 此阶段包装的 PyTorch 模块。

  • stage_index (int) – 此阶段的 ID。

  • num_stages (int) – 阶段总数。

  • device (torch.device) – 此阶段所在的设备。

  • input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输入参数。

  • output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输出参数。

  • group (dist.ProcessGroup, optional) – 分布式训练的进程组。如果为 None,则使用默认组。

  • dw_builder (Optional[Callable[[], Callable[..., None]]) – 如果提供,dw_builder 将构建一个新的 dw_runner 函数,该函数将执行 F、I、W(前向、输入、权重)零气泡调度的 W 动作(输入权重)。

torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)[source]#

创建一个流水线阶段,给定要由该阶段包装的 stage_module 和流水线信息。

参数
  • stage_module (torch.nn.Module) – 要由该阶段包装的模块

  • stage_index (int) – 此阶段在流水线中的索引

  • pipe_info (PipeInfo) – 有关流水线的信息,可通过 pipe.info() 检索

  • device (torch.device) – 此阶段要使用的设备

  • group (Optional[dist.ProcessGroup]) – 此阶段要使用的进程组

返回

一个可与 PipelineSchedules 一起运行的流水线阶段。

返回类型

_PipelineStage

流水线调度#

class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#

GPipe 调度。将以填充-排空方式遍历所有微批次。

class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#

1F1B 调度。将在稳态下对微批次执行一次前向和一次反向。

class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#

交错 1F1B 调度。有关详细信息,请参阅 https://arxiv.org/pdf/2104.04473。它将在稳态下对微批次执行一次前向和一次反向,并支持每个 rank 多个阶段。当微批次准备好用于多个局部阶段时,交错 1F1B 优先处理较早的微批次(也称为“深度优先”)。

此调度与原始论文基本相似。不同之处在于它放宽了 num_microbatch % pp_size == 0 的要求。使用 flex_pp 调度,我们将有 num_rounds = max(1, n_microbatches // pp_group_size),只要 n_microbatches % num_rounds 为 0,它就有效。举几个例子,支持

  1. pp_group_size = 4,n_microbatches = 10。我们将有 num_rounds = 2,n_microbatches % 2 为 0。

  2. pp_group_size = 4,n_microbatches = 3。我们将有 num_rounds = 1,n_microbatches % 1 为 0。

class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None, scale_grads=True)[source]#

广度优先流水线并行。有关详细信息,请参阅 https://arxiv.org/abs/2211.05953。与交错 1F1B 类似,循环 BFS 支持每个 rank 多个阶段。不同之处在于,当微批次准备好用于多个局部阶段时,循环 BFS 将优先处理较早的阶段,一次运行所有可用的微批次。

class torch.distributed.pipelining.schedules.ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#

交错零气泡调度。有关详细信息,请参阅 https://arxiv.org/pdf/2401.10241。它将在稳态下对微批次的输入执行一次前向和一次反向,并支持每个 rank 多个阶段。使用权重的反向传播来填充流水线气泡。

特别是,这实现了论文中的 ZB1P 调度。

class torch.distributed.pipelining.schedules.ScheduleZBVZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#

零气泡调度(ZBV 变体)。有关详细信息,请参阅 https://arxiv.org/pdf/2401.10241 第 6 节。

此调度要求每个 rank 恰好有两个阶段。

此调度将在稳态下对微批次的输入执行一次前向和一次反向,并支持每个 rank 多个阶段。使用关于权重的反向传播来填充流水线气泡。

此 ZB-V 调度仅当正向时间 == 反向输入时间 == 反向权重时间时才具有“零气泡”特性。在实践中,对于实际模型而言,这不太可能成立,因此可以实现一个贪婪调度器来处理不相等/不平衡的时间。

class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#

单阶段调度的基类。实现了 step 方法。派生类应实现 _step_microbatches

根据 scale_grads 参数(默认为 True),梯度按 num_microbatches 进行缩放。此设置应与您的 loss_fn 配置匹配,loss_fn 可能平均损失(scale_grads=True)或求和损失(scale_grads=False)。

step(*args, target=None, losses=None, **kwargs)[source]#

使用*整批*输入运行流水线调度的一个迭代。将自动将输入分块为微批次,并根据调度实现遍历微批次。

args:模型的位置参数(与非流水线情况相同)。kwargs:模型的关键字参数(与非流水线情况相同)。target:损失函数的目标。losses:一个列表,用于存储每个微批次的损失。

class torch.distributed.pipelining.schedules.PipelineScheduleMulti(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, use_full_backward=None, scale_grads=True)[source]#

多阶段调度的基类。实现了 step 方法。

根据 scale_grads 参数(默认为 True),梯度按 num_microbatches 进行缩放。此设置应与您的 loss_fn 配置匹配,loss_fn 可能平均损失(scale_grads=True)或求和损失(scale_grads=False)。

step(*args, target=None, losses=None, **kwargs)[source]#

使用*整批*输入运行流水线调度的一个迭代。将自动将输入分块为微批次,并根据调度实现遍历微批次。

args:模型的位置参数(与非流水线情况相同)。kwargs:模型的关键字参数(与非流水线情况相同)。target:损失函数的目标。losses:一个列表,用于存储每个微批次的损失。