评价此页

Tensor 并行 - torch.distributed.tensor.parallel#

创建于: 2025年6月13日 | 最后更新于: 2025年6月13日

Tensor 并行 (TP) 构建在 PyTorch DistributedTensor (DTensor)[https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md] 之上,并提供不同的并行样式:逐列 (Colwise)、逐行 (Rowwise) 和序列并行 (Sequence Parallelism)。

警告

Tensor 并行 API 处于实验阶段,可能会发生更改。

使用 Tensor 并行来并行化您的 nn.Module 的入口点是:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None, *, src_data_rank=0)[source]#

通过根据用户指定的计划并行化模块或子模块,在 PyTorch 中应用 Tensor 并行。

我们根据 `parallelize_plan` 来并行化模块或子模块。`parallelize_plan` 包含 `ParallelStyle`,它指示用户希望模块或子模块如何被并行化。

用户还可以为每个模块的完全限定名 (FQN) 指定不同的并行样式。

请注意,parallelize_module 只接受一维 DeviceMesh。如果您有一个二维或 N 维 DeviceMesh,请先将其切片到一个一维子 DeviceMesh,然后再传递给此 API (例如 device_mesh["tp"])。

参数
  • module (nn.Module) – 要并行化的模块。

  • device_mesh (DeviceMesh, optional) – 描述 DTensor 设备网格拓扑的对象。如果未指定,调用必须在 `DeviceMesh` 上下文中进行。

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]], optional) – 用于并行化模块的计划。它可以是包含我们如何准备 Tensor 并行输入的/输出的 `ParallelStyle` 对象,也可以是模块 FQN 及其对应的 `ParallelStyle` 对象的字典。如果未指定,调用将不做任何操作。

关键字参数

src_data_rank (int, optional) – 逻辑/全局张量的源数据的 rank,它由 distribute_tensor() 用于将分片/副本散布/广播到其他 rank。默认情况下,我们在每个 `DeviceMesh` 维度上使用 `group_rank=0` 作为源数据,以保留单设备语义。如果显式传递 None,则 parallelize_module() 仅使用其本地数据,而不是尝试通过散布/广播来保留单设备语义。默认值:0

返回

一个已并行化的 nn.Module 对象。

返回类型

模块

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>>

注意

对于复杂的模块架构,如 Attention、MLP 层,我们建议组合不同的 ParallelStyle (例如 ColwiseParallelRowwiseParallel) 并将它们作为 `parallelize_plan` 传递,以实现所需的 Sharding 计算。

Tensor 并行支持以下并行样式:

class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]#

以逐列 (column-wise) 的方式划分兼容的 `nn.Module`。目前支持 `nn.Linear` 和 `nn.Embedding`。用户可以将其与 `RowwiseParallel` 组合以实现更复杂的模块 (例如 MLP、Attention) 的划分。

关键字参数
  • input_layouts (Placement, optional) – `nn.Module` 的输入张量的 DTensor 布局,用于注解输入张量以使其成为 DTensor。如果未指定,我们假设输入张量是复制的。

  • output_layouts (Placement, optional) – `nn.Module` 输出的 DTensor 布局,用于确保 `nn.Module` 的输出具有用户期望的布局。如果未指定,输出张量将在最后一个维度上进行分片。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 `DTensor` 作为模块输出,默认为 True。

返回

一个表示 `nn.Module` 逐列分片的 `ParallelStyle` 对象。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
>>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...

注意

默认情况下,如果未指定 `output_layouts`,`ColwiseParallel` 的输出将在最后一个维度上进行分片。如果存在需要特定张量形状的运算符 (例如,在配对的 `RowwiseParallel` 之前),请注意,如果输出被分片,则该运算符可能需要调整以适应分片后的大小。

class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]#

以逐行 (row-wise) 的方式划分兼容的 `nn.Module`。目前支持 `nn.Linear` 和 `nn.Embedding`。用户可以将其与 `ColwiseParallel` 组合以实现更复杂的模块 (例如 MLP、Attention) 的划分。

关键字参数
  • input_layouts (Placement, optional) – `nn.Module` 的输入张量的 DTensor 布局,用于注解输入张量以使其成为 DTensor。如果未指定,我们假设输入张量在最后一个维度上进行分片。

  • output_layouts (Placement, optional) – `nn.Module` 输出的 DTensor 布局,用于确保 `nn.Module` 的输出具有用户期望的布局。如果未指定,输出张量将被复制。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 `DTensor` 作为模块输出,默认为 True。

返回

一个表示 `nn.Module` 逐行分片的 `ParallelStyle` 对象。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
>>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}),
>>> ...
class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[source]#

SequenceParallel 复制兼容的 nn.Module 参数,并使用在序列维度上分片的输入进行分片计算。这目前支持 nn.LayerNormnn.Dropout 以及 RMSNorm Python 实现

此样式实现了论文 Reducing Activation Recomputation in Large Transformer Models 中描述的操作。

如果传递给此 nn.Module 的输入是 torch.Tensor,它会假设输入已在序列维度上分片,并将输入转换为在序列维度上分片的 DTensor。如果传递给此 nn.Module 的输入已经是 DTensor 但未在序列维度上分片,它将重新分发输入以在序列维度上分片。

`nn.Module` 的输出将在序列维度上进行分片。

关键字参数
  • sequence_dim (int, optional) – `nn.Module` 的输入张量的序列维度,用于注解输入张量以使其成为在序列维度上分片的 DTensor,默认值为 1。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作为模块输出,默认值为 False。

返回

一个表示 `nn.Module` 序列并行的 `ParallelStyle` 对象。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
>>> ...

注意

SequenceParallel 样式假设 `nn.Module` 中的权重 (例如 nn.LayerNormRMSNorm) 具有全1初始化 (默认情况下)。如果您对这些模块的权重有自定义初始化,您需要在并行化之前/之后广播权重,以确保它们被复制。

要仅使用 `parallelize_module` 调用中的 `parallelize_plan` 来配置 `nn.Module` 的输入和输出的 DTensor 布局并执行必要的布局重分布,而不分发模块参数到 DTensor,可以使用以下 `ParallelStyle`:

class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[source]#

根据 `input_layouts` 配置 `nn.Module` 的输入,在运行时将 `nn.Module` 的输入张量转换为 DTensor,并根据 `desired_input_layouts` 进行布局重分布。

关键字参数
  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的输入张量的 DTensor 布局,用于将输入张量转换为 DTensor。如果某些输入不是 `torch.Tensor` 或不需要转换为 DTensor,则需要指定 `None` 作为占位符。默认为 None。

  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 输入张量的期望 DTensor 布局,用于确保 `nn.Module` 的输入具有期望的 DTensor 布局。此参数的长度必须与 `input_layouts` 相同。默认为 None。

  • input_kwarg_layouts (Dict[str, Placement]) – `nn.Module` 的输入 kwargs 的 DTensor 布局,用于将输入 kwargs 张量转换为 DTensor。默认为 None。

  • desired_input_kwarg_layouts – (Dict[str, Placement]): `nn.Module` 的输入 kwargs 的期望 DTensor 布局,用于确保 `nn.Module` 的输入具有期望的 DTensor 布局。默认为 None。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作为模块输入,默认为 False。

返回

一个 `ParallelStyle` 对象,用于准备 `nn.Module` 输入的分片布局。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
>>> # and then redistributed to Replicated DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...)
>>>         ),
>>>     }
>>> )
class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[source]#

根据 `output_layouts` 配置 `nn.Module` 的输出,在运行时将 `nn.Module` 的输出张量转换为 DTensor,并根据 `desired_output_layouts` 进行布局重分布。

关键字参数
  • output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的输出张量的 DTensor 布局,用于将输出张量转换为 DTensor (如果它们是 torch.Tensor)。如果某些输出不是 `torch.Tensor` 或不需要转换为 DTensor,则需要指定 `None` 作为占位符。

  • desired_output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 输出张量的期望 DTensor 布局,用于确保 `nn.Module` 的输出具有期望的 DTensor 布局。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作为模块输出,默认为 True。

返回

一个 `ParallelStyle` 对象,用于准备 `nn.Module` 输出的分片布局。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
>>> # and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan = PrepareModuleOutput(
>>>         output_layouts=Replicate(),
>>>         desired_output_layouts=Shard(0)
>>>     )
>>> )
class torch.distributed.tensor.parallel.PrepareModuleInputOutput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_input=False, output_layouts, desired_output_layouts, use_local_output=True)[source]#

配置 `nn.Module` 的输入 (和输出),在运行时根据 `input_layouts` (和 `output_layouts`,分别) 将 `nn.Module` 的输入张量 (和输出张量,分别) 转换为 DTensor,并根据 `desired_input_layouts` (和 `desired_output_layouts`,分别) 进行布局重分布。这是 `PrepareModuleInput` 和 `PrepareModuleOutput` 的组合。

关键字参数
  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的输入张量的 DTensor 布局,用于将输入张量转换为 DTensor。如果某些输入不是 `torch.Tensor` 或不需要转换为 DTensor,则需要指定 `None` 作为占位符。默认为 None。

  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 输入张量的期望 DTensor 布局,用于确保 `nn.Module` 的输入具有期望的 DTensor 布局。此参数的长度必须与 `input_layouts` 相同。默认为 None。

  • input_kwarg_layouts (Dict[str, Placement]) – `nn.Module` 的输入 kwargs 的 DTensor 布局,用于将输入 kwargs 张量转换为 DTensor。默认为 None。

  • desired_input_kwarg_layouts – (Dict[str, Placement]): `nn.Module` 的输入 kwargs 的期望 DTensor 布局,用于确保 `nn.Module` 的输入具有期望的 DTensor 布局。默认为 None。

  • use_local_input (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作为模块输入,默认为 False。

  • output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的输出张量的 DTensor 布局,用于将输出张量转换为 DTensor (如果它们是 torch.Tensor)。如果某些输出不是 `torch.Tensor` 或不需要转换为 DTensor,则需要指定 `None` 作为占位符。

  • desired_output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 输出张量的期望 DTensor 布局,用于确保 `nn.Module` 的输出具有期望的 DTensor 布局。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作为模块输出,默认为 True。

返回

一个 `ParallelStyle` 对象,用于准备 `nn.Module` 输入和输出的分片布局。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor
>>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated
>>> # as Replicated DTensor and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInputOutput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...),
>>>             output_layouts=Replicate(),
>>>             desired_output_layouts=Shard(0),
>>>         ),
>>>     }
>>> )

注意

当使用 `Shard(dim)` 作为上述 `ParallelStyle` 的输入/输出布局时,我们假设输入/输出激活张量在 TP 操作的 `DeviceMesh` 上均匀分片于张量维度 `dim`。例如,由于 `RowwiseParallel` 接受在最后一个维度上分片的输入,它假设输入张量已经均匀分片于最后一个维度。对于不均匀分片的激活张量,可以先将 DTensor 直接传递给分片后的模块,并使用 `use_local_output=False` 在每个 `ParallelStyle` 后返回 DTensor,DTensor 可以跟踪不均匀分片信息。

对于 Transformer 等模型,我们建议用户在 `parallelize_plan` 中一起使用 `ColwiseParallel` 和 `RowwiseParallel`,以实现整个模型 (例如 Attention 和 MLP) 的期望分片。

可以通过以下上下文管理器支持并行化的交叉熵损失计算 (损失并行):

torch.distributed.tensor.parallel.loss_parallel()[source]#

一个启用损失并行的上下文管理器,当输入在类别维度上分片时,可以执行高效的并行化损失计算。目前仅支持交叉熵损失。

在此上下文管理器中,您可以像平常一样使用 `cross_entropy()` 或 `CrossEntropyLoss`,并满足以下输入参数假设。相应的 backward() 调用 (如果有) 也需要在此上下文管理器下进行。

参数
  • input (DTensor) – 输入 logits。假定在类别维度上分片。

  • target (Union[torch.Tensor, DTensor]) – 必须是真实类索引 (目前不支持类别概率)。假定在 `DeviceMesh` 上复制。

  • weight (Union[torch.Tensor, DTensor], optional) – 如果提供,假定在 `DeviceMesh` 上复制。

  • label_smoothing – 目前不支持。

返回

一个复制的 `DTensor`。

示例

此处手动创建了一个分片 DTensor 以展示用法。实际上,它通常是 TP 模块的输出。

>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
>>>     loss.backward()
>>> ...

警告

The loss_parallel API is experimental and subject to change.