Tensor Parallelism - torch.distributed.tensor.parallel#
Created On: Jun 13, 2025 | Last Updated On: Sep 09, 2025
Tensor Parallelism (TP) 基于 PyTorch DistributedTensor (DTensor) 构建,并提供不同的并行化风格:列式并行 (Colwise)、行式并行 (Rowwise) 和序列并行 (Sequence Parallelism)。
警告
Tensor Parallelism API 处于实验阶段,可能会发生更改。
使用 Tensor Parallelism 并行化您的 nn.Module 的入口点是
- torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None, *, src_data_rank=0)[source]#
通过根据用户指定的计划并行化模块或子模块,在 PyTorch 中应用 Tensor Parallelism。
我们根据 `parallelize_plan` 来并行化模块或子模块。`parallelize_plan` 包含
ParallelStyle,它指示用户希望如何并行化模块或子模块。用户还可以为每个模块的完全限定名 (FQN) 指定不同的并行化风格。
请注意,`parallelize_module` 只接受一维
DeviceMesh。如果您有一个二维或 N 维DeviceMesh,请先将其切片到一个一维子DeviceMesh,然后再传递给此 API (例如,`device_mesh["tp"]`)- 参数:
module (
nn.Module) – 要并行化的模块。device_mesh (
DeviceMesh, optional) – 描述 DTensor 设备网格拓扑的对象。如果未指定,调用必须在 DeviceMesh 上下文中。parallelize_plan (Union[
ParallelStyle, Dict[str,ParallelStyle]], optional) – 用于并行化模块的计划。它可以是一个ParallelStyle对象,其中包含我们为 Tensor Parallelism 准备输入/输出的方式,也可以是模块 FQN 及其对应的ParallelStyle对象的字典。如果未指定,该调用目前将不执行任何操作。
- 关键字参数:
src_data_rank (int, optional) – 逻辑/全局张量的源数据排名,它由 `distribute_tensor()` 用于将分片/副本分散/广播到其他排名。默认情况下,我们在每个 DeviceMesh 维上使用 `group_rank=0` 作为源数据,以保留单设备语义。如果显式传递 `None`,`parallelize_module()` 将仅使用其本地数据,而不是尝试通过 scatter/broadcast 来保留单设备语义。默认值:0
- 返回:
一个并行化后的
nn.Module对象。- 返回类型:
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> >>> # Define the module. >>> m = Model(...) >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}) >>>
注意
对于复杂的模块架构,如 Attention、MLP 层,我们建议将不同的 `ParallelStyle` 组合在一起(例如,
ColwiseParallel和RowwiseParallel),并将其作为 `parallelize_plan` 传递,以实现所需的切分计算。
Tensor Parallelism 支持以下并行化风格
- class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]#
以列式方式对兼容的 `nn.Module` 进行分区。目前支持 `nn.Linear` 和 `nn.Embedding`。用户可以将其与 `RowwiseParallel` 组合以实现更复杂模块(例如 MLP、Attention)的分区。
- 关键字参数:
input_layouts (Placement, optional) – `nn.Module` 的输入张量的 DTensor 布局,用于将输入张量注解为 DTensor。如果未指定,我们假设输入张量是复制的。
output_layouts (Placement, optional) – `nn.Module` 输出的 DTensor 布局,用于确保 `nn.Module` 的输出具有用户期望的布局。如果未指定,输出张量将在最后一个维度上分片。
use_local_output (bool, optional) – 是否使用本地 `torch.Tensor` 而不是 `DTensor` 作为模块输出,默认值:True。
- 返回:
一个表示 `nn.Module` 列式分片的 `ParallelStyle` 对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) >>> ...
注意
默认情况下,如果未指定 `output_layouts`,`ColwiseParallel` 的输出将在最后一个维度上分片。如果存在需要特定张量形状的操作(例如,在配对的 `RowwiseParallel` 之前),请记住,如果输出已分片,该操作可能需要调整以适应分片大小。
- class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]#
以行式方式对兼容的 `nn.Module` 进行分区。目前支持 `nn.Linear` 和 `nn.Embedding`。用户可以将其与 `ColwiseParallel` 组合以实现更复杂模块(例如 MLP、Attention)的分区。
- 关键字参数:
input_layouts (Placement, optional) – `nn.Module` 的输入张量的 DTensor 布局,用于将输入张量注解为 DTensor。如果未指定,我们假设输入张量在最后一个维度上分片。
output_layouts (Placement, optional) – `nn.Module` 输出的 DTensor 布局,用于确保 `nn.Module` 的输出具有用户期望的布局。如果未指定,输出张量将被复制。
use_local_output (bool, optional) – 是否使用本地 `torch.Tensor` 而不是 `DTensor` 作为模块输出,默认值:True。
- 返回:
一个表示 `nn.Module` 行式分片的 `ParallelStyle` 对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>> ...
- class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[source]#
SequenceParallel 复制兼容的 `nn.Module` 参数,并使用在序列维度上分片的输入进行分片计算。这目前支持 `nn.LayerNorm`、`nn.Dropout` 以及 RMSNorm Python 实现。
此风格实现了论文 Reducing Activation Recomputation in Large Transformer Models 中描述的操作。
如果传递给此 `nn.Module` 的输入是 `torch.Tensor`,则假定输入已在序列维度上分片,并将输入转换为在序列维度上分片的 `DTensor`。如果传递给此 `nn.Module` 的输入已经是 `DTensor` 但未在序列维度上分片,它将重新分布输入以在序列维度上分片。
该 `nn.Module` 的输出将在序列维度上分片。
- 关键字参数:
- 返回:
一个表示 `nn.Module` 序列并行的 `ParallelStyle` 对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>> ...
注意
SequenceParallel 风格假定 `nn.Module` 中的权重(例如 `nn.LayerNorm` 或 `RMSNorm`,它们默认具有全一初始化)的初始化是全一的。如果您对这些模块上的权重有自定义初始化,您需要在并行化之前/之后广播权重,以确保它们被复制。
要仅在调用 `parallelize_module` 时,使用 `parallelize_plan` 中的以下 `ParallelStyle` 来配置 `nn.Module` 的输入和输出的 DTensor 布局并执行必要的布局重分布,而不分配模块参数到 DTensor:
- class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[source]#
根据 `input_layouts` 在运行时配置 `nn.Module` 的输入,将 `nn.Module` 的输入张量转换为 DTensor,并根据 `desired_input_layouts` 执行布局重分布。
- 关键字参数:
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的输入张量的 DTensor 布局,用于将输入张量转换为 DTensor。如果某些输入不是 `torch.Tensor` 或不需要转换为 DTensor,则需要指定 `None` 作为占位符。默认值:None。
desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的输入张量的期望 DTensor 布局,用于确保 `nn.Module` 的输入具有期望的 DTensor 布局。此参数需要与 `input_layouts` 的长度相同。默认值:None。
input_kwarg_layouts (Dict[str, Placement]) – `nn.Module` 的输入 kwargs 的 DTensor 布局,用于将输入 kwargs 张量转换为 DTensor。默认值:None
desired_input_kwarg_layouts – (Dict[str, Placement]): `nn.Module` 的输入 kwargs 的期望 DTensor 布局,用于确保 `nn.Module` 的输入具有期望的 DTensor 布局。默认值:None。
use_local_output (bool, optional) – 是否使用本地 `torch.Tensor` 而不是 `DTensor` 作为模块输入,默认值:False。
- 返回:
一个准备 `nn.Module` 输入的分片布局的 `ParallelStyle` 对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor >>> # and then redistributed to Replicated DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan={ >>> "attn": PrepareModuleInput( >>> input_layouts=(Shard(0), None, None, ...), >>> desired_input_layouts=(Replicate(), None, None, ...) >>> ), >>> } >>> )
- class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[source]#
根据 `output_layouts` 在运行时配置 `nn.Module` 的输出,将 `nn.Module` 的输出张量转换为 DTensor,并根据 `desired_output_layouts` 执行布局重分布。
- 关键字参数:
output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的输出张量的 DTensor 布局,用于将输出张量转换为 DTensor(如果它们是 `torch.Tensor`)。如果某些输出不是 `torch.Tensor` 或不需要转换为 DTensor,则需要指定 `None` 作为占位符。
desired_output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的输出张量的期望 DTensor 布局,用于确保 `nn.Module` 的输出具有期望的 DTensor 布局。
use_local_output (bool, optional) – 是否使用本地 `torch.Tensor` 而不是 `DTensor` 作为模块输出,默认值:True。
- 返回:
一个准备 `nn.Module` 输出的分片布局的 `ParallelStyle` 对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor >>> # and then redistributed to Sharded DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan = PrepareModuleOutput( >>> output_layouts=Replicate(), >>> desired_output_layouts=Shard(0) >>> ) >>> )
- class torch.distributed.tensor.parallel.PrepareModuleInputOutput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_input=False, output_layouts, desired_output_layouts, use_local_output=True)[source]#
根据 `input_layouts`(以及 `output_layouts`,分别)在运行时配置 `nn.Module` 的输入(以及输出),将 `nn.Module` 的输入张量(以及输出张量,分别)转换为 DTensor,并根据 `desired_input_layouts`(以及 `desired_output_layouts`,分别)执行布局重分布。这是 `PrepareModuleInput` 和 `PrepareModuleOutput` 的组合。
- 关键字参数:
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的输入张量的 DTensor 布局,用于将输入张量转换为 DTensor。如果某些输入不是 `torch.Tensor` 或不需要转换为 DTensor,则需要指定 `None` 作为占位符。默认值:None。
desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的输入张量的期望 DTensor 布局,用于确保 `nn.Module` 的输入具有期望的 DTensor 布局。此参数需要与 `input_layouts` 的长度相同。默认值:None。
input_kwarg_layouts (Dict[str, Placement]) – `nn.Module` 的输入 kwargs 的 DTensor 布局,用于将输入 kwargs 张量转换为 DTensor。默认值:None
desired_input_kwarg_layouts – (Dict[str, Placement]): `nn.Module` 的输入 kwargs 的期望 DTensor 布局,用于确保 `nn.Module` 的输入具有期望的 DTensor 布局。默认值:None。
use_local_input (bool, optional) – 是否使用本地 `torch.Tensor` 而不是 `DTensor` 作为模块输入,默认值:False。
output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的输出张量的 DTensor 布局,用于将输出张量转换为 DTensor(如果它们是 `torch.Tensor`)。如果某些输出不是 `torch.Tensor` 或不需要转换为 DTensor,则需要指定 `None` 作为占位符。
desired_output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的输出张量的期望 DTensor 布局,用于确保 `nn.Module` 的输出具有期望的 DTensor 布局。
use_local_output (bool, optional) – 是否使用本地 `torch.Tensor` 而不是 `DTensor` 作为模块输出,默认值:True。
- 返回:
一个准备 `nn.Module` 输入和输出的分片布局的 `ParallelStyle` 对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor >>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated >>> # as Replicated DTensor and then redistributed to Sharded DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan={ >>> "attn": PrepareModuleInputOutput( >>> input_layouts=(Shard(0), None, None, ...), >>> desired_input_layouts=(Replicate(), None, None, ...), >>> output_layouts=Replicate(), >>> desired_output_layouts=Shard(0), >>> ), >>> } >>> )
注意
当使用上述 `ParallelStyle` 的 `input/output_layouts` 作为 `Shard(dim)` 时,我们假设输入/输出激活张量在 TP 操作所在的 `DeviceMesh` 上的 `dim` 维度上是均匀分片的。例如,由于 `RowwiseParallel` 接受在最后一个维度上分片的输入,它假设输入张量已经过最后一个维度的均匀分片。对于非均匀分片激活张量的情况,可以先将 DTensor 直接传递给分区模块,并使用 `use_local_output=False` 在每个 `ParallelStyle` 后返回 DTensor,DTensor 可以跟踪非均匀分片信息。
对于 Transformer 等模型,我们建议用户在 `parallelize_plan` 中同时使用 `ColwiseParallel` 和 `RowwiseParallel`,以实现整个模型(例如 Attention 和 MLP)所需的切分。
并行化交叉熵损失计算(损失并行)通过以下上下文管理器支持
- torch.distributed.tensor.parallel.loss_parallel()[source]#
一个启用损失并行的上下文管理器,当输入在类别维度上分片时,可以执行高效的并行化损失计算。目前仅支持交叉熵损失。
在此上下文管理器中,您可以像平常一样使用 `cross_entropy()` 或 `CrossEntropyLoss`,并对输入参数进行以下假设。相应的 `backward()` 调用(如果存在)也需要在此上下文管理器下进行。
- 参数:
input (
DTensor) – 输入 logits。假设在类别维度上分片。target (Union[
torch.Tensor,DTensor]) – 必须是真实类索引(目前不支持类概率)。假设在 `DeviceMesh` 上复制。weight (Union[
torch.Tensor,DTensor], optional) – 如果提供,假设在 `DeviceMesh` 上复制。label_smoothing – 目前不支持。
- 返回:
一个复制的 `DTensor`。
示例
此处手动创建一个分片的 DTensor 来演示用法。实际上,它通常是 TP 模块的输出。
>>> from torch.distributed.tensor.parallel import loss_parallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> device_mesh = init_device_mesh("cuda", (8,)) >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) >>> target = torch.randint(16, (4,), device="cuda") >>> with loss_parallel(): >>> loss = F.cross_entropy(dist_input, target, reduction="mean") >>> loss.backward() >>> ...
警告
The loss_parallel API is experimental and subject to change.