管道并行#
创建于:2025 年 6 月 16 日 | 最后更新于:2025 年 6 月 16 日
注意
torch.distributed.pipelining
目前处于 Alpha 阶段,正在开发中。API 可能会有变动。它从 PiPPy 项目迁移而来。
为什么选择管道并行?#
管道并行是深度学习的**原始**并行方式之一。它允许将模型**执行**进行分区,以便多个**微批次**可以同时执行模型的不同部分。管道并行是一种有效的技术,适用于
大规模训练
带宽受限的集群
大型模型推理
上述场景的共同点是,每个设备的计算无法掩盖传统并行(例如 FSDP 的权重全收集)的通信开销。
torch.distributed.pipelining
是什么?#
虽然管道化对于扩展很有前景,但通常难以实现,因为它除了模型权重之外还需要**划分模型的执行**。执行划分通常需要对模型进行侵入式代码更改。复杂性的另一个方面来自于**在分布式环境中调度微批次**,同时考虑到**数据流依赖**。
pipelining
包提供了一个工具包,可以**自动**完成上述操作,从而在**通用**模型上轻松实现管道并行。
它由两部分组成:一个**分割前端**和一个**分布式运行时**。分割前端接收您的模型代码,将其分割成“模型分区”,并捕获数据流关系。分布式运行时并行地在不同设备上执行管道阶段,处理微批次分割、调度、通信和梯度传播等。
总的来说,pipelining
包提供了以下功能:
根据简单规范分割模型代码。
丰富地支持管道调度,包括 GPipe、1F1B、交错式 1F1B 和循环 BFS,并提供编写自定义调度的基础设施。
对跨主机管道并行的第一类支持,因为这通常是 PP 使用的地方(通过较慢的互连)。
与其他PyTorch并行技术(如数据并行(DDP、FSDP)或张量并行)的可组合性。TorchTitan项目演示了Llama模型上的“3D并行”应用。
步骤 1:构建 PipelineStage
#
在使用 PipelineSchedule
之前,我们需要创建 PipelineStage
对象来包装在该阶段运行的模型部分。PipelineStage
负责分配通信缓冲区并创建发送/接收操作以与其对等方通信。它管理中间缓冲区,例如尚未消耗的前向输出,并提供了一个用于运行阶段模型反向的实用程序。
PipelineStage
需要知道阶段模型的输入和输出形状,以便正确分配通信缓冲区。形状必须是静态的,例如,在运行时形状不能逐步改变。如果运行时形状与预期形状不匹配,将引发 PipeliningShapeError
类错误。在与其他并行技术组合或应用混合精度时,必须考虑这些技术,以便 PipelineStage
知道阶段模块在运行时的正确形状(和数据类型)。
用户可以直接构造 PipelineStage
实例,通过传入一个表示模型应在该阶段运行部分的 nn.Module
。这可能需要更改原始模型代码。请参阅 选项 1:手动拆分模型 中的示例。
或者,拆分前端可以使用图分区自动将您的模型拆分为一系列 nn.Module
。此技术要求模型可使用 torch.Export
追踪。结果 nn.Module
与其他并行技术的组合性仍处于实验阶段,可能需要一些变通方法。如果用户不能轻松更改模型代码,使用此前端可能更具吸引力。有关更多信息,请参阅 选项 2:自动拆分模型。
步骤 2:使用 PipelineSchedule
执行#
我们现在可以将 PipelineStage
附加到流水线调度,并使用输入数据运行调度。以下是一个GPipe示例
from torch.distributed.pipelining import ScheduleGPipe
# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)
# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)
# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically
if rank == 0:
schedule.step(x)
else:
output = schedule.step()
请注意,上述代码需要为每个 worker 启动,因此我们使用启动服务来启动多个进程。
torchrun --nproc_per_node=2 example.py
模型拆分选项#
选项 1:手动拆分模型#
要直接构造 PipelineStage
,用户需要负责提供一个 nn.Module
实例,该实例拥有相关的 nn.Parameters
和 nn.Buffers
,并定义一个执行该阶段相关操作的 forward()
方法。例如,Torchtitan 中定义的 Transformer 类的精简版本展示了一种构建易于分区模型的模式。
class Transformer(nn.Module):
def __init__(self, model_args: ModelArgs):
super().__init__()
self.tok_embeddings = nn.Embedding(...)
# Using a ModuleDict lets us delete layers without affecting names,
# ensuring checkpoints will correctly save and load.
self.layers = torch.nn.ModuleDict()
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = TransformerBlock(...)
self.output = nn.Linear(...)
def forward(self, tokens: torch.Tensor):
# Handling layers being 'None' at runtime enables easy pipeline splitting
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
for layer in self.layers.values():
h = layer(h, self.freqs_cis)
h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
return output
以这种方式定义的模型可以通过以下方式轻松配置每个阶段:首先初始化整个模型(使用元设备以避免 OOM 错误),删除该阶段不需要的层,然后创建一个包装模型的 PipelineStage。例如
with torch.device("meta"):
assert num_stages == 2, "This is a simple 2-stage example"
# we construct the entire model, then delete the parts we do not need for this stage
# in practice, this can be done using a helper function that automatically divides up layers across stages.
model = Transformer()
if stage_index == 0:
# prepare the first stage model
del model.layers["1"]
model.norm = None
model.output = None
elif stage_index == 1:
# prepare the second stage model
model.tok_embeddings = None
del model.layers["0"]
from torch.distributed.pipelining import PipelineStage
stage = PipelineStage(
model,
stage_index,
num_stages,
device,
)
当与其他数据或模型并行技术组合时,如果模型块的输出形状/数据类型会受到影响,可能还需要 output_args
。
选项 2:自动拆分模型#
如果您有一个完整的模型,并且不想花时间将其修改为一系列“模型分区”,pipeline
API 可以帮助您。以下是一个简短示例
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.emb = torch.nn.Embedding(10, 3)
self.layers = torch.nn.ModuleList(
Layer() for _ in range(2)
)
self.lm = LMHead()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.emb(x)
for layer in self.layers:
x = layer(x)
x = self.lm(x)
return x
如果我们打印模型,可以看到多个层次结构,这使得手动拆分变得困难
Model(
(emb): Embedding(10, 3)
(layers): ModuleList(
(0-1): 2 x Layer(
(lin): Linear(in_features=3, out_features=3, bias=True)
)
)
(lm): LMHead(
(proj): Linear(in_features=3, out_features=3, bias=True)
)
)
让我们看看 pipeline
API 如何工作
from torch.distributed.pipelining import pipeline, SplitPoint
# An example micro-batch input
x = torch.LongTensor([1, 2, 4, 5])
pipe = pipeline(
module=mod,
mb_args=(x,),
split_spec={
"layers.1": SplitPoint.BEGINNING,
}
)
pipeline
API 根据 split_spec
拆分您的模型,其中 SplitPoint.BEGINNING
表示在 forward
函数中某个子模块执行*之前*添加拆分点,类似地,SplitPoint.END
表示在*之后*添加拆分点。
如果我们 print(pipe)
,我们可以看到
GraphModule(
(submod_0): GraphModule(
(emb): InterpreterModule()
(layers): Module(
(0): InterpreterModule(
(lin): InterpreterModule()
)
)
)
(submod_1): GraphModule(
(layers): Module(
(1): InterpreterModule(
(lin): InterpreterModule()
)
)
(lm): InterpreterModule(
(proj): InterpreterModule()
)
)
)
def forward(self, x):
submod_0 = self.submod_0(x); x = None
submod_1 = self.submod_1(submod_0); submod_0 = None
return (submod_1,)
“模型分区”由子模块(submod_0
、submod_1
)表示,每个子模块都使用原始模型操作、权重和层次结构进行重建。此外,还重建了一个“根级”forward
函数,以捕获这些分区之间的数据流。这种数据流稍后将由流水线运行时以分布式方式重放。
Pipe
对象提供了一个检索“模型分区”的方法
stage_mod : nn.Module = pipe.get_stage_module(stage_idx)
返回的 stage_mod
是一个 nn.Module
,您可以使用它创建优化器、保存或加载检查点,或应用其他并行化。
Pipe
还允许您在给定 ProcessGroup
的情况下在设备上创建分布式阶段运行时
stage = pipe.build_stage(stage_idx, device, group)
或者,如果您想在对 stage_mod
进行一些修改后构建阶段运行时,可以使用 build_stage
API 的函数版本。例如
from torch.distributed.pipelining import build_stage
from torch.nn.parallel import DistributedDataParallel
dp_mod = DistributedDataParallel(stage_mod)
info = pipe.info()
stage = build_stage(dp_mod, stage_idx, info, device, group)
注意
pipeline
前端使用跟踪器(torch.export
)将您的模型捕获到单个图中。如果您的模型无法进行完整图捕获,您可以使用我们下面的手动前端。
Hugging Face 示例#
在最初创建此包的 PiPPy 存储库中,我们保留了基于未修改的 Hugging Face 模型的示例。请参阅 examples/huggingface 目录。
示例包括
技术深入探究#
pipeline
API 如何拆分模型?#
首先,pipeline
API 通过跟踪模型将其转换为有向无环图(DAG)。它使用 torch.export
(一个 PyTorch 2 完整图捕获工具)跟踪模型。
然后,它将一个阶段所需的**操作和参数**组合到一个重构的子模块中:submod_0
、submod_1
等。
与传统子模块访问方法(如 Module.children()
)不同,pipeline
API 不仅切割模型的模块结构,还切割模型的 forward
函数。
这是必要的,因为像 Module.children()
这样的模型结构仅在 Module.__init__()
期间捕获信息,而不捕获任何关于 Module.forward()
的信息。换句话说,Module.children()
缺少对流水线化关键以下方面的信息
forward
中子模块的执行顺序子模块之间的激活流
子模块之间是否存在任何函数式运算符(例如,
relu
或add
操作将不会被Module.children()
捕获)。
相反,pipeline
API 确保 forward
行为真正得以保留。它还捕获分区之间的激活流,帮助分布式运行时无需人工干预即可进行正确的发送/接收调用。
pipeline
API 的另一个灵活性是拆分点可以在模型层次结构中的任意级别。在拆分分区中,与该分区相关的原始模型层次结构将无代价地重建。结果是,指向子模块或参数的完全限定名称(FQN)仍然有效,并且依赖 FQN 的服务(例如 FSDP、TP 或检查点)仍然可以以几乎零代码更改的方式与您分区的模块一起运行。
实现您自己的调度#
您可以通过扩展以下两个类之一来实现您自己的流水线调度
PipelineScheduleSingle
PipelineScheduleMulti
PipelineScheduleSingle
适用于每个 rank 仅分配*一个*阶段的调度。PipelineScheduleMulti
适用于每个 rank 分配多个阶段的调度。
例如,ScheduleGPipe
和 Schedule1F1B
是 PipelineScheduleSingle
的子类。而 ScheduleInterleaved1F1B
、ScheduleLoopedBFS
、ScheduleInterleavedZeroBubble
和 ScheduleZBVZeroBubble
是 PipelineScheduleMulti
的子类。
日志记录#
您可以使用 torch._logging 中的 TORCH_LOGS
环境变量开启附加日志记录
TORCH_LOGS=+pp
将显示logging.DEBUG
消息及其以上所有级别。TORCH_LOGS=pp
将显示logging.INFO
消息及其以上级别。TORCH_LOGS=-pp
将显示logging.WARNING
消息及其以上级别。
API 参考#
模型拆分 API#
以下 API 集将您的模型转换为流水线表示。
- class torch.distributed.pipelining.SplitPoint(value)[source]#
表示子模块执行中可能发生拆分点的枚举。:ivar BEGINNING: 表示在 forward 函数中某个子模块执行*之前*添加拆分点。:ivar END: 表示在 forward 函数中某个子模块执行*之后*添加拆分点。
- torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)[source]#
根据规范拆分模块。
有关更多详细信息,请参阅 Pipe。
- 参数
module (Module) – 要拆分的模块。
mb_kwargs (Optional[dict[str, Any]]) – 关键字输入示例,采用微批次形式。(默认值:None)
split_spec (Optional[dict[str, torch.distributed.pipelining._IR.SplitPoint]]) – 使用子模块名称作为拆分标记的字典。(默认值:None)
split_policy (Optional[Callable[[GraphModule], GraphModule]]) – 用于拆分模块的策略。(默认值:None)
- 返回类型
类 Pipe 的流水线表示。
微批处理工具#
- torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks(args, kwargs, chunks, args_chunk_spec=None, kwargs_chunk_spec=None)[source]#
给定一系列 args 和 kwargs,根据它们各自的分块规范将其拆分为多个块。
- 参数
chunks (int) – 要将 args 和 kwargs 拆分成的块数
args_chunk_spec (Optional[tuple[torch.distributed.pipelining.microbatch.TensorChunkSpec, ...]]) – args 的分块规范,与 args 形状相同
kwargs_chunk_spec (Optional[dict[str, torch.distributed.pipelining.microbatch.TensorChunkSpec]]) – kwargs 的分块规范,与 kwargs 形状相同
- 返回
分片 args 列表 kwargs_split: 分片 kwargs 列表
- 返回类型
args_split
流水线阶段#
- class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args=None, output_args=None, group=None, dw_builder=None)[source]#
表示流水线并行设置中流水线阶段的类。
PipelineStage 假设模型是顺序分区的,即模型被分成若干块,其中一个块的输出作为下一个块的输入,没有跳过连接。
PipelineStage 通过将输出从阶段 0 传播到阶段 1,依此类推,以线性顺序自动执行运行时形状/数据类型推断。要绕过形状推断,请将 input_args 和 output_args 传递给每个 PipelineStage 实例。
- 参数
submodule (nn.Module) – 此阶段包装的 PyTorch 模块。
stage_index (int) – 此阶段的 ID。
num_stages (int) – 阶段总数。
device (torch.device) – 此阶段所在的设备。
input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输入参数。
output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输出参数。
group (dist.ProcessGroup, optional) – 分布式训练的进程组。如果为 None,则使用默认组。
dw_builder (Optional[Callable[[], Callable[..., None]]) – 如果提供,dw_builder 将构建一个新的 dw_runner 函数,该函数将执行 F、I、W(前向、输入、权重)零气泡调度的 W 动作(输入权重)。
- torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)[source]#
创建一个流水线阶段,给定要由该阶段包装的 stage_module 和流水线信息。
- 参数
stage_module (torch.nn.Module) – 要由该阶段包装的模块
stage_index (int) – 此阶段在流水线中的索引
pipe_info (PipeInfo) – 有关流水线的信息,可通过 pipe.info() 检索
device (torch.device) – 此阶段要使用的设备
group (Optional[dist.ProcessGroup]) – 此阶段要使用的进程组
- 返回
一个可与 PipelineSchedules 一起运行的流水线阶段。
- 返回类型
_PipelineStage
流水线调度#
- class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#
GPipe 调度。将以填充-排空方式遍历所有微批次。
- class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#
1F1B 调度。将在稳态下对微批次执行一次前向和一次反向。
- class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#
交错 1F1B 调度。有关详细信息,请参阅 https://arxiv.org/pdf/2104.04473。它将在稳态下对微批次执行一次前向和一次反向,并支持每个 rank 多个阶段。当微批次准备好用于多个局部阶段时,交错 1F1B 优先处理较早的微批次(也称为“深度优先”)。
此调度与原始论文基本相似。不同之处在于它放宽了 num_microbatch % pp_size == 0 的要求。使用 flex_pp 调度,我们将有 num_rounds = max(1, n_microbatches // pp_group_size),只要 n_microbatches % num_rounds 为 0,它就有效。举几个例子,支持
pp_group_size = 4,n_microbatches = 10。我们将有 num_rounds = 2,n_microbatches % 2 为 0。
pp_group_size = 4,n_microbatches = 3。我们将有 num_rounds = 1,n_microbatches % 1 为 0。
- class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None, scale_grads=True)[source]#
广度优先流水线并行。有关详细信息,请参阅 https://arxiv.org/abs/2211.05953。与交错 1F1B 类似,循环 BFS 支持每个 rank 多个阶段。不同之处在于,当微批次准备好用于多个局部阶段时,循环 BFS 将优先处理较早的阶段,一次运行所有可用的微批次。
- class torch.distributed.pipelining.schedules.ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#
交错零气泡调度。有关详细信息,请参阅 https://arxiv.org/pdf/2401.10241。它将在稳态下对微批次的输入执行一次前向和一次反向,并支持每个 rank 多个阶段。使用权重的反向传播来填充流水线气泡。
特别是,这实现了论文中的 ZB1P 调度。
- class torch.distributed.pipelining.schedules.ScheduleZBVZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#
零气泡调度(ZBV 变体)。有关详细信息,请参阅 https://arxiv.org/pdf/2401.10241 第 6 节。
此调度要求每个 rank 恰好有两个阶段。
此调度将在稳态下对微批次的输入执行一次前向和一次反向,并支持每个 rank 多个阶段。使用关于权重的反向传播来填充流水线气泡。
此 ZB-V 调度仅当正向时间 == 反向输入时间 == 反向权重时间时才具有“零气泡”特性。在实践中,对于实际模型而言,这不太可能成立,因此可以实现一个贪婪调度器来处理不相等/不平衡的时间。
- class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source]#
单阶段调度的基类。实现了 step 方法。派生类应实现 _step_microbatches。
根据 scale_grads 参数(默认为 True),梯度按 num_microbatches 进行缩放。此设置应与您的 loss_fn 配置匹配,loss_fn 可能平均损失(scale_grads=True)或求和损失(scale_grads=False)。
- class torch.distributed.pipelining.schedules.PipelineScheduleMulti(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, use_full_backward=None, scale_grads=True)[source]#
多阶段调度的基类。实现了 step 方法。
根据 scale_grads 参数(默认为 True),梯度按 num_microbatches 进行缩放。此设置应与您的 loss_fn 配置匹配,loss_fn 可能平均损失(scale_grads=True)或求和损失(scale_grads=False)。