torch.distributed.tensor#
创建于: Jun 13, 2025 | 最后更新于: Oct 28, 2025
注意
torch.distributed.tensor 目前处于 alpha 状态且正在开发中,我们致力于保持文档中所列大多数 API 的向后兼容性,但必要时可能会进行 API 更改。
PyTorch DTensor (分布式张量)#
PyTorch DTensor 提供了简单灵活的张量分片(sharding)原语,能够透明地处理分布式逻辑,包括跨设备/主机的分片存储、算子计算和集体通信。 DTensor 可用于构建各种并行化解决方案,并在处理多维分片时支持分片 state_dict 表示。
请参阅基于 DTensor 构建的 PyTorch 原生并行化解决方案的示例。
DTensor 遵循 SPMD (Single Program, Multiple Data) 编程模型,使用户能够像编写 **单设备程序一样编写分布式程序,并具有相同的收敛性**。它通过指定 DeviceMesh 和 Placement 来提供统一的张量分片布局(DTensor Layout)。
DeviceMesh使用 n 维数组表示集群的设备拓扑和通信器。Placement描述了逻辑张量在DeviceMesh上的分片布局。DTensor 支持三种类型的 placement:Shard、Replicate和Partial。
DTensor 类 API#
DTensor 是 torch.Tensor 的子类。这意味着一旦创建了 DTensor,它就可以以与 torch.Tensor 非常相似的方式使用,包括运行不同类型的 PyTorch 算子,就如同在单设备上运行一样,为 PyTorch 算子实现适当的分布式计算。
除了现有的 torch.Tensor 方法外,它还提供了一系列额外的方法来与 torch.Tensor 交互,例如将 DTensor 的布局 redistribute 到新的 DTensor,获取所有设备上的完整张量内容等。
- class torch.distributed.tensor.DTensor(**kwargs)#
DTensor(Distributed Tensor) 是torch.Tensor的一个子类,它提供了类似单设备的抽象,用于使用多设备torch.Tensor进行编程。它通过DeviceMesh和以下类型的Placement来描述分布式张量的分片布局(DTensor Layout)。Shard:张量在DeviceMesh的dim维度上分片到DeviceMesh的设备上。Replicate:张量在DeviceMesh的设备上复制。Partial:张量在DeviceMesh的设备上等待规约。
在调用 PyTorch 算子时,
DTensor会重写 PyTorch 算子,在必要时执行分片计算并发出通信。与算子计算一起,DTensor会正确地转换或传播 placement(DTensor Layout)(基于算子本身的语义),并生成新的DTensor输出。为了确保调用 PyTorch 算子时
DTensor分片计算的数值正确性,DTensor要求算子的每个张量参数都必须是 DTensor。注意
直接使用 Tensor 子类构造函数在这里不是创建
DTensor的推荐方式(即它不能正确处理 autograd,因此不是公共 API)。请参考 create_dtensor 部分,了解如何创建DTensor。- __create_chunk_list__()[source]#
返回一个 ChunkStorageMetadata 对象列表,它是一个数据类,描述了当前 rank 上本地分片/副本的大小/偏移量。对于 DTensor,每个 rank 将只有一个本地分片/副本,因此返回的列表通常只有一个元素。
此双下划线方法主要用于分布式 checkpoint 目的。
- 返回:
一个 List[
ChunkStorageMetadata] 对象,表示当前 rank 上的分片大小/偏移量。
- static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source]#
根据指定的
device_mesh和placements,在每个 rank 上从一个本地 torch.Tensor 创建一个DTensor。- 参数:
local_tensor (torch.Tensor) – 每个 rank 上的本地 torch.Tensor。
device_mesh (
DeviceMesh, optional) – 用于放置张量的 DeviceMesh。如果未指定,必须在 DeviceMesh 上下文管理器下调用,默认为 None。placements (List[
Placement], optional) – 描述如何将本地 torch.Tensor 放置在 DeviceMesh 上的 placement。必须具有与device_mesh.ndim相同的元素数量。
- 关键字参数:
run_check (bool, optional) – 以额外的通信为代价,在 rank 之间执行健全性检查,以检查每个本地张量的元信息以确保正确性。如果
placements中包含Replicate,则设备网格维度上的第一个 rank 的数据将被广播到其他 rank。默认为 False。shape (torch.Size, optional) – 一个 int 列表,指定基于 local_tensor 构建的 DTensor 的大小。注意,如果 rank 之间的
local_tensor的形状不同,则需要提供此参数。如果未提供,将假设给定的分布式张量在 rank 之间均匀分片,并据此计算shape。默认为 None。stride (tuple, optional) – 一个 int 列表,指定 DTensor 的步幅。如果未提供,将假设给定的分布式张量在 rank 之间均匀分片,并据此计算
stride。默认为 None。
- 返回:
一个
DTensor对象。- 返回类型:
注意
当
run_check=False时,用户有责任确保传入的本地张量在 rank 之间是正确的(即,张量是为Shard(dim)placement 分片的,或者是为Replicate()placement 复制的)。如果不是,则创建的 DTensor 的行为是未定义的。注意
from_local是可微分的,创建的 DTensor 对象的 requires_grad 将取决于 local_tensor 是否 requires_grad。
- full_tensor(*, grad_placements=None)[source]#
返回此 DTensor 的完整张量。它将执行必要的集体操作,以收集其 DeviceMesh 中的本地张量并将它们连接起来。这是以下代码的语法糖:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()- 关键字参数:
grad_placements (List[
Placement], optional) – 描述从此函数返回的完整张量的任何梯度布局的未来布局的 placement。 full_tensor 将 DTensor 转换为完整的 torch.Tensor,返回的 torch.tensor 可能无法在后续代码中用作原始的复制 DTensor 布局。此参数是用户可以提供给 autograd 的提示,以防返回张量的梯度布局与原始复制 DTensor 布局不匹配。如果未指定,我们将假设完整张量的梯度布局为复制。- 返回:
一个
torch.Tensor对象,表示此 DTensor 的完整张量。- 返回类型:
注意
full_tensor是可微分的。
- redistribute(device_mesh=None, placements=None, *, async_op=False, forward_dtype=None, backward_dtype=None)[source]#
redistribute执行必要的集体操作,将当前 DTensor 从其当前 placement 重定向到新的 placement,或从其当前 DeviceMesh 重定向到新的 DeviceMesh。即,通过为 DeviceMesh 的每个维度指定 Replicate placement,可以将分片的 DTensor 转换为复制的 DTensor。当在一个设备网格维度上从当前 placement 重定向到新 placement 时,我们将执行以下操作,包括通信集体操作或本地操作:
Shard(dim)->Replicate():all_gatherShard(src_dim)->Shard(dst_dim):all_to_allReplicate()->Shard(dim): 本地分块(即torch.chunk)Partial()->Replicate():all_reducePartial()->Shard(dim):reduce_scatter
redistribute将能够正确地确定 DTensor 在 1D 或 N-D DeviceMesh 上所需的重定向步骤。- 参数:
device_mesh (
DeviceMesh, optional) – 用于放置 DTensor 的 DeviceMesh。如果未指定,将使用当前 DTensor 的 DeviceMesh。默认为 None。placements (List[
Placement], optional) – 描述如何将 DTensor 放置到 DeviceMesh 中的新 placement,必须具有与device_mesh.ndim相同的元素数量。默认为在所有 mesh 维度上复制。
- 关键字参数:
async_op (bool, optional) – 是否异步执行 DTensor 重定向操作。默认为 False。
forward_dtype (torch.dtype, optional) – 在其 forward 中重定向本地张量之前,可以将本地张量的数据类型转换为
forward_dtype。结果 DTensor 将是forward_dtype类型。默认为 None。backward_dtype (torch.dtype, optional) – 在其 backward 中重定向本地张量之前,可以将本地张量的数据类型转换为
backward_dtype。结果 DTensor 的梯度将转换回当前 DTensor 的数据类型。默认为 None。
- 返回:
一个
DTensor对象。- 返回类型:
注意
redistribute是可微分的,这意味着用户无需担心重定向操作的 backward 公式。注意
redistribute目前仅支持在相同的 DeviceMesh 上重定向 DTensor。如果您需要将 DTensor 重定向到不同的 DeviceMesh,请提交一个 issue。
- to_local(*, grad_placements=None)[source]#
获取当前 rank 上此 DTensor 的本地张量。对于分片,它返回逻辑张量视图的本地分片;对于复制,它返回当前 rank 上的副本。
- 关键字参数:
grad_placements (List[
Placement], optional) – 描述从此函数返回的张量的任何梯度布局的未来布局的 placement。 to_local 将 DTensor 转换为本地张量,返回的本地张量可能无法在后续代码中用作原始 DTensor 布局。此参数是用户可以提供给 autograd 的提示,以防返回张量的梯度布局与原始 DTensor 布局不匹配。如果未指定,我们将假设梯度布局与原始 DTensor 保持相同,并用于梯度计算。- 返回:
一个
torch.Tensor或AsyncCollectiveTensor对象。它表示其当前 rank 上的本地张量。当返回AsyncCollectiveTensor对象时,表示本地张量尚未就绪(即通信尚未完成)。在这种情况下,用户需要调用wait来等待本地张量就绪。- 返回类型:
注意
to_local是可微分的,返回的本地张量的requires_grad将取决于 DTensor 是否 requires_grad。
- property device_mesh: DeviceMesh#
与此 DTensor 对象关联的
DeviceMesh属性。注意
device_mesh是一个只读属性,无法设置。
DeviceMesh 作为分布式通信器#
DeviceMesh 是从 DTensor 构建的,用作描述集群设备拓扑并表示多维通信器(基于 ProcessGroup)的抽象。有关如何创建/使用 DeviceMesh 的详细信息,请参阅 DeviceMesh 教程。
DTensor Placement 类型#
DTensor 支持在每个 DeviceMesh 维度上的以下 Placement 类型:
- class torch.distributed.tensor.placement_types.Shard[source]#
Shard(dim)placement 描述了 DTensor 在其对应的 `DeviceMesh` 维度上的张量维度 `dim` 上的分片,其中 DeviceMesh 维度上的每个 rank 只持有全局张量的一个分片/部分。 `Shard(dim)` placement 遵循 `torch.chunk(dim)` 语义,当张量维度不能在 DeviceMesh 维度上整除时,DeviceMesh 维度上的最后几个分片可能为空。 `Shard` placement 可以被所有 DTensor API 使用(例如 distribute_tensor、from_local 等)。- 参数:
dim (int) – 描述 DTensor 在其对应的 DeviceMesh 维度上分片的张量维度。
警告
张量维度大小不能在 DeviceMesh 维度上整除的分片目前是实验性的,并且可能发生变化。
- class torch.distributed.tensor.placement_types.Replicate[source]#
Replicate()placement 描述了 DTensor 在其对应的 `DeviceMesh` 维度上的复制,其中 DeviceMesh 维度上的每个 rank 持有全局张量的副本。 `Replicate` placement 可以被所有 DTensor API 使用(例如 `distribute_tensor`、`DTensor.from_local` 等)。
- class torch.distributed.tensor.placement_types.Partial[source]#
Partial(reduce_op)placement 描述了在指定的 `DeviceMesh` 维度上等待规约的 DTensor,其中 DeviceMesh 维度上的每个 rank 持有全局张量的部分值。用户可以使用 `redistribute` 将 `Partial` DTensor 重定向到指定 `DeviceMesh` 维度上的 `Replicate` 或 `Shard(dim)` placement,这将触发底层必要的通信操作(例如 `allreduce`、`reduce_scatter`)。- 参数:
reduce_op (str, optional) – 用于 Partial DTensor 以生成 Replicated/Sharded DTensor 的规约操作。仅支持逐元素规约操作,包括:“sum”、“avg”、“product”、“max”、“min”,默认为“sum”。
注意
Partialplacement 可以作为 DTensor 操作的结果生成,并且只能由 `DTensor.from_local` API 使用。
- class torch.distributed.tensor.placement_types.MaskPartial(reduce_op=None, mask_buffer=None, offset_shape=None, offset_dim=0, *args, **kwargs)[source]#
为行式分片的嵌入操作设计的局部遮罩放置类型,在这种操作中,我们需要遮罩索引并将其调整到本地嵌入分片。嵌入遮罩是 Partial 放置类型的一种特殊形式。
注意:此 MaskPartial 放置类型的生命周期遵循相应的 DTensor 的生命周期,即 indices_mask 仅在 DTensor 的生命周期内有效。
- mask_buffer: MaskBuffer#
- class torch.distributed.tensor.placement_types.Placement#
Placement 类型的基类,它描述了一个 DTensor 如何放置在
DeviceMesh上。Placement和DeviceMesh一起可以描述 DTensor 的布局。它是三种主要的 DTensor Placement 类型:Shard、Replicate和Partial的基类。此类不应直接使用,主要用作类型提示。
- is_shard(self: torch._C._distributed.Placement, dim: SupportsInt | None = None) bool#
创建 DTensor 的不同方式
- 有三种方式可以构造一个
DTensor distribute_tensor()从每个 rank 上的逻辑或“全局”torch.Tensor创建一个DTensor。这可以用于分片叶子torch.Tensor(即模型参数/缓冲区和输入)。DTensor.from_local()从每个 rank 上的本地torch.Tensor创建一个DTensor,可用于从非叶子torch.Tensor(即前向/后向过程中的中间激活张量)创建DTensor。DTensor 提供了专门的张量工厂函数(例如
empty()、ones()、randn()等),允许通过直接指定DeviceMesh和Placement来创建不同的DTensor。与distribute_tensor()相比,这可以直接在设备上实现分片内存,而不是在初始化逻辑张量内存后执行分片。
从逻辑 torch.Tensor 创建 DTensor
torch.distributed 中的 SPMD(单程序多数据)编程模型通过(例如 torchrun)启动多个进程来执行相同的程序,这意味着程序内的模型将首先在不同的进程上初始化(即模型可能在 CPU、元设备上初始化,或者如果内存足够,则直接在 GPU 上初始化)。
DTensor 提供了一个 distribute_tensor() API,可以将模型权重或张量分片到 DTensor 中,它会从每个进程上的“逻辑”张量创建一个 DTensor。这将使创建的 DTensor 符合单设备语义,这对于**数值正确性**至关重要。
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)[source]#
根据指定的
placements将叶子torch.Tensor(例如 nn.Parameter/buffers)分发到device_mesh。device_mesh和placements的 rank 必须相同。要分发的tensor是逻辑或“全局”张量,API 将使用 DeviceMesh 维度上的第一个 rank 的tensor作为真相来源,以保持单设备语义。如果您想在 Autograd 计算的中间创建 DTensor,请改用DTensor.from_local()。- 参数:
tensor (torch.Tensor) – 要分发的 torch.Tensor。请注意,如果您想在一个不能被该 mesh 维度上的设备数量整除的维度上分片张量,我们将使用
torch.chunk语义来分片张量并散布分片。不均匀分片行为是实验性的,可能会发生变化。device_mesh (
DeviceMesh, optional) – 分发张量的 DeviceMesh,如果未指定,则必须在 DeviceMesh 上下文管理器下调用,默认为 Noneplacements (List[
Placement], optional) – 描述如何将张量放置在 DeviceMesh 上的 placements,必须具有与device_mesh.ndim相同的元素数量。如果未指定,我们将默认在 device_mesh 的每个维度的第一个 rank 上将张量复制到device_mesh。
- 关键字参数:
src_data_rank (int, optional) – 逻辑/全局张量的源数据的 rank,它被
distribute_tensor()用于将分片/副本分发/广播到其他 rank。默认情况下,我们在每个 DeviceMesh 维度上使用group_rank=0作为源数据,以保持单设备语义。如果显式传递None,distribute_tensor()将直接使用其本地数据,而不是尝试通过散布/广播来保持单设备语义。默认值:0- 返回:
一个
DTensor或XLAShardedTensor对象。- 返回类型:
注意
当使用
xla设备类型初始化 DeviceMesh 时,distribute_tensor返回 XLAShardedTensor。更多详情请参阅此问题。XLA 集成是实验性的,可能会发生变化。
除了 distribute_tensor() 之外,DTensor 还提供了一个 distribute_module() API,以便在 nn.Module 级别更容易进行分片。
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)[source]#
此函数公开三个函数来控制模块的参数/输入/输出
1. 通过指定
partition_fn在运行时执行之前对模块进行分片(即允许用户根据指定的 partition_fn 将 Module 参数转换为DTensor参数)。2. 在运行时执行期间通过指定input_fn和output_fn来控制模块的输入或输出。(即,将输入转换为DTensor,将输出转换回torch.Tensor)- 参数:
module (
nn.Module) – 用户待分区的模块。device_mesh (
DeviceMesh) – 放置模块的设备网格。partition_fn (Callable) – 分区参数的函数(即在
device_mesh上分片某些参数)。如果未指定partition_fn,则默认情况下我们将module的所有模块参数复制到网格中。input_fn (Callable) – 指定输入分布,即可以控制模块的输入如何分片。
input_fn将被安装为模块的forward_pre_hook(前向预钩子)。output_fn (Callable) – 指定输出分布,即可以控制输出如何分片,或将其转换回 torch.Tensor。
output_fn将被安装为模块的forward_hook(后向钩子)。
- 返回:
一个包含所有
DTensor参数/缓冲区的模块。- 返回类型:
注意
当使用
xla设备类型初始化 DeviceMesh 时,distribute_module返回带有 PyTorch/XLA SPMD 注释的参数的 nn.Module。更多详情请参阅此问题。XLA 集成是实验性的,可能会发生变化。
DTensor 工厂函数
DTensor 还提供了专门的张量工厂函数,允许直接使用类似 torch.Tensor 的工厂函数 API(例如,torch.ones、torch.empty 等)创建 DTensor,通过额外指定为 DTensor 创建的 DeviceMesh 和 Placement。
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]#
返回一个用标量值 0 填充的
DTensor。- 参数:
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:zeros(1,2,3..) 或 zeros([1,2,3..]) 或 zeros((1,2,3..))- 关键字参数:
requires_grad (bool, optional) – 是否应自动梯度记录返回的
DTensor上的操作。默认为False。dtype (
torch.dtype, optional) – 返回的DTensor的所需数据类型。默认值:如果None,则使用全局默认值(请参阅torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的DTensor的所需布局。默认值:torch.strided。device_mesh –
DeviceMesh类型,包含 rank 的网格信息placements – 一系列
Placement类型:Shard、Replicate
- 返回:
每个 rank 上的
DTensor对象- 返回类型:
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]#
返回一个用标量值 1 填充的
DTensor,其形状由可变参数size定义。- 参数:
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数:
dtype (
torch.dtype, optional) – 返回的DTensor的所需数据类型。默认值:如果None,则使用全局默认值(请参阅torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided。requires_grad (bool, optional) – 是否应自动梯度记录返回的
DTensor上的操作。默认为False。device_mesh –
DeviceMesh类型,包含 rank 的网格信息placements – 一系列
Placement类型:Shard、Replicate
- 返回:
每个 rank 上的
DTensor对象- 返回类型:
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]#
返回一个用未初始化数据填充的
DTensor。DTensor的形状由可变参数size定义。- 参数:
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))- 关键字参数:
dtype (
torch.dtype, optional) – 返回的DTensor的所需数据类型。默认值:如果None,则使用全局默认值(请参阅torch.set_default_dtype())。 layout (torch.layout, optional):返回的DTensor的所需布局。默认值:torch.strided。requires_grad (bool, optional) – 是否应自动梯度记录返回的
DTensor上的操作。默认为False。device_mesh –
DeviceMesh类型,包含 rank 的网格信息placements – 一系列
Placement类型:Shard、Replicate
- 返回:
每个 rank 上的
DTensor对象- 返回类型:
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]#
根据
device_mesh和placements,使用fill_value填充一个DTensor,其形状由参数size定义。- 参数:
- 关键字参数:
dtype (
torch.dtype, optional) – 返回的DTensor的所需数据类型。默认值:如果None,则使用全局默认值(请参阅torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided。requires_grad (bool, optional) – 是否应自动梯度记录返回的
DTensor上的操作。默认为False。device_mesh –
DeviceMesh类型,包含 rank 的网格信息。placements – 一系列
Placement类型:Shard、Replicate
- 返回:
每个 rank 上的
DTensor对象- 返回类型:
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]#
返回一个在区间
[0, 1)上均匀分布的随机数填充的DTensor。张量的形状由可变参数size定义。- 参数:
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数:
dtype (
torch.dtype, optional) – 返回的DTensor的所需数据类型。默认值:如果None,则使用全局默认值(请参阅torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided。requires_grad (bool, optional) – 是否应自动梯度记录返回的
DTensor上的操作。默认为False。device_mesh –
DeviceMesh类型,包含 rank 的网格信息。placements – 一系列
Placement类型:Shard、Replicate
- 返回:
每个 rank 上的
DTensor对象- 返回类型:
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[源码]#
返回一个
DTensor,其中填充了均值为 0、方差为 1 的正态分布随机数。张量的形状由可变参数size定义。- 参数:
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数:
dtype (
torch.dtype, optional) – 返回的DTensor的所需数据类型。默认值:如果None,则使用全局默认值(请参阅torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided。requires_grad (bool, optional) – 是否应自动梯度记录返回的
DTensor上的操作。默认为False。device_mesh –
DeviceMesh类型,包含 rank 的网格信息。placements – 一系列
Placement类型:Shard、Replicate
- 返回:
每个 rank 上的
DTensor对象- 返回类型:
随机操作#
DTensor 提供了分布式 RNG 功能,以确保分片张量上的随机操作获得唯一的值,而复制张量上的随机操作获得相同的值。该系统要求所有参与的 rank(例如 SPMD rank)在使用每个 dtensor 随机操作之前都使用相同的生成器状态开始,如果满足此条件,则确保在每个 dtensor 随机操作完成后,它们都处于相同的状态。在随机操作期间不执行通信来同步 RNG 状态。
接受 generator 关键字参数的操作将使用用户传入的生成器(如果已传入),否则使用设备的默认生成器。无论使用哪个生成器,它都将在 DTensor 操作之后推进。将相同的生成器用于 DTensor 和非 DTensor 操作是有效的,但必须谨慎确保非 DTensor 操作在所有 rank 上以相同的方式推进生成器状态。
当将 DTensor 与流水线并行一起使用时,每个流水线阶段的 rank 应使用不同的种子,而流水线阶段内的 rank 应使用相同的种子。
DTensor 的 RNG 基础设施基于 philox 基础的 RNG 算法,并支持任何 philox 基础的后端(cuda 和其他类 cuda 设备),但不幸的是,尚不支持 CPU 后端。
调试#
日志记录#
启动程序时,您可以使用 TORCH_LOGS 环境变量从 torch._logging 启用其他日志记录。
TORCH_LOGS=+dtensor将显示logging.DEBUG消息及其之上所有级别的消息。TORCH_LOGS=dtensor将显示logging.INFO消息及其之上所有级别的消息。TORCH_LOGS=-dtensor将显示logging.WARNING消息及其之上所有级别的消息。
调试工具#
为了调试应用了 DTensor 的程序,并了解底层发生的通信操作的更多详细信息,DTensor 提供了一个 CommDebugMode。
- class torch.distributed.tensor.debug.CommDebugMode#
CommDebugMode是一个上下文管理器,用于计算其上下文内的功能性通信操作的数量。它使用TorchDispatchMode来实现这一点。注意
并非所有通信操作都已支持。
使用示例
mod = ... comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)[源码]#
生成详细表格,显示模块级别的操作和通信跟踪信息。信息量取决于 noise_level。
打印模块级别的通信计数。
打印未包含在琐碎操作中的 dTensor 操作,模块信息。
打印未包含在琐碎操作中的操作。
打印所有操作。
要可视化少于 3 个维度的 DTensor 的分片,DTensor 提供 visualize_sharding()。
实验性功能#
DTensor 还提供了一系列实验性功能。这些功能要么处于原型阶段,要么基本功能已完成但正在征求用户反馈。如果您对这些功能有任何反馈,请向 PyTorch 提交 issue。
- torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)[源码]#
context_parallel是一个用于启用上下文并行 (CP) 的实验性 API。此 API 执行两个操作:1) 使用启用 CP 的 SDPA(torch.nn.functional.scaled_dot_product_attention)进行补丁;2) 沿序列维度分片buffers,每个 rank 将根据mesh保留相应的分片。- 参数:
mesh (
DeviceMesh) – 用于上下文并行的设备网格。buffers (Optional[List[torch.Tensor]]) – 其使用依赖于序列维度的缓冲区。例如输入批次、标签和位置嵌入缓冲区。这些缓冲区必须沿序列维度分片以确保准确性。分片将就地进行,缓冲区在上下文内的形状会发生变化。上下文结束后,缓冲区将被恢复。
no_restore_buffers可用于指定哪些缓冲区不需要恢复。请注意,buffers不应包含任何 nn.Parameter。buffer_seq_dims (Optional[List[int]]) –
buffers的序列维度。no_restore_buffers (Optional[Set[torch.Tensor]]) – 这些集合中的缓冲区在上下文退出后不会被恢复。此集合必须是
buffers的子集。如果上下文退出后不再使用这些缓冲区,可以将它们放入此列表中以避免额外的恢复时间。
- 返回类型:
Generator[None, None, None]
警告
torch.distributed.tensor.experimental.context_parallel 是 PyTorch 中的一个原型功能。API 可能会发生变化。
- torch.distributed.tensor.experimental.local_map(func=None, out_placements=None, in_placements=None, in_grad_placements=None, device_mesh=None, *, redistribute_inputs=False)[源码]#
local_map()是一个实验性 API,它允许用户将DTensor传递给一个函数,该函数被编写为应用于torch.Tensor。这是通过提取DTensor的局部组件,调用函数,并根据out_placements将输出包装回DTensor来实现的。- 参数:
func (Callable) – 要应用于
DTensors 的每个局部分片的函数。out_placements (Union[PlacementType, Tuple[PlacementType, …]]) –
func的展平输出中DTensors 的期望放置。如果展平的output是单个值,则out_placements应为 PlacementType 类型。否则,如果展平的output有多个值,则out_placements应为 PlacementType 值的元组,与展平的output一对一映射。此外,对于Tensor输出,我们使用 PlacementType 作为其放置(Tuple[Placement] 值)。对于非 Tensor 输出,PlacementType 应为 None。请注意,唯一的例外是当未传入任何DTensor参数时。在这种情况下,即使 out_placements 不是 None,结果函数也应忽略期望的放置,因为该函数不是使用DTensors 运行的。in_placements (Tuple[PlacementType, …], optional) –
func的展平输入中与DTensors 梯度对应的期望放置。如果指定了in_placements,local_map()将检查每个DTensor参数的放置是否与期望的放置相同。如果放置不相同且redistribute_inputs为False,则会引发异常。否则,如果redistribute_inputs为True,则在将参数传递给func的局部张量之前,将首先对其进行重分布以达到期望的分片放置。唯一的例外是当期望的放置不是None且参数是torch.Tensor时。在这种情况下,将跳过放置检查,并将参数直接传递给func。如果in_placements为None,则不会执行放置检查。默认值:None。in_grad_placements (Tuple[PlacementType, …], optional) –
DTensors 梯度对应的放置提示,对应于展平的输入 DTensor。此参数是用户可以提供给to_local()的提示,以防局部张量输入的梯度布局与其DTensor输入布局不匹配。如果未指定,我们将假设局部张量输入的梯度布局与原始DTensor输入保持相同,并使用该布局进行梯度计算。默认值:None。device_mesh (
DeviceMesh, optional) – 输出DTensors 放置在其上的设备网格。如果未指定,将从第一个输入DTensor的设备网格推断。默认值:None。
- 关键字参数:
redistribute_inputs (bool, optional) – 指示是否在输入
DTensors 的放置与期望的输入放置不同时重新分片它们的布尔值。如果此值为False且某些DTensor输入具有不同的放置,则会引发异常。默认值:False。- 返回:
一个
Callable,它将func应用于输入DTensor的每个局部分片,并返回一个由func的返回值构造的DTensor。- 抛出:
AssertionError – 对于任何非 DTensor 输出,我们要求其在
out_placements中的相应输出放置为 None。如果不是这种情况,将引发 AssertionError。ValueError – 如果
redistribute_inputs=False但根据in_placements,输入DTensor需要重分布。
示例
>>> def mm_allreduce_forward(device_mesh, W, X): >>> partial_sum_tensor = torch.mm(W, X) >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) >>> return reduced_tensor >>> >>> W = torch.randn(12, 8, requires_grad=False) >>> X = torch.randn(8, 16, requires_grad=False) >>> Y = torch.mm(W, X) >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], >>> in_placements=[col_wise, row_wise], >>> device_mesh=device_mesh, >>> ) >>> >>> W_dt = distribute_tensor( ... W, device_mesh, (col_wise) ... ) # col-wisely sharded W tensor >>> X_dt = distribute_tensor( ... X, device_mesh, (row_wise) ... ) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward( ... device_mesh, W_dt, X_dt ... ) # apply local_mm_allreduce_forward to DTensors
注意
此 API 目前是实验性的,可能会发生更改。
- torch.distributed.tensor.experimental.register_sharding(op)[源码]#
register_sharding()是一个实验性 API,它允许用户在张量输入和输出为 DTensor 时注册运算符的分片策略。当以下情况时,它可能很有用:(1)op没有默认的分片策略,例如当op是 DTensor 不支持的自定义运算符时;(2) 当用户想要覆盖现有运算符的默认分片策略时。- 参数:
op (Union[OpOverload, List[OpOverload]]) – 要注册自定义分片函数的运算符或运算符列表。
- 返回:
一个函数装饰器,可用于包装一个定义
op中指定的运算符的分片策略的函数。定义的が分片策略将注册到 DTensor,如果 DTensor 已实现该运算符,则会覆盖默认的分片策略。自定义分片函数接受与原始 op 相同的输入(除了如果参数是torch.Tensor,它将被 DTensor 内部使用的类张量对象替换)。该函数应返回一个 2 元组序列,每个元组指定可接受的输出放置及其相应的输入放置。
示例
>>> @register_sharding(aten._softmax.default) >>> def custom_softmax_sharding(x, dim, half_to_float): >>> softmax_dim = dim if dim >= 0 else dim + x.ndim >>> acceptable_shardings = [] >>> >>> all_replicate = ([Replicate()], [Replicate(), None, None]) >>> acceptable_shardings.append(all_replicate) >>> >>> for sharding_dim in range(x.ndim): >>> if sharding_dim != softmax_dim: >>> all_sharded = ( >>> [Shard(sharding_dim)], >>> [Shard(sharding_dim), None, None], >>> ) >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings
注意
此 API 目前是实验性的,可能会发生更改。
混合张量和 DTensor 操作#
因此,您收到了以下错误消息。
got mixed torch.Tensor and DTensor, need to convert all
torch.Tensor to DTensor before calling distributed operators!
有两种情况。
情况 1:这是用户错误#
遇到此错误的最常见方法是创建一个常规张量(使用工厂函数),然后执行张量-DTensor 操作,如下所示:
tensor = torch.arange(10)
return tensor + dtensor
我们不允许混合张量-DTensor 操作:如果任何操作(例如 torch.add)的输入是 DTensor,那么所有张量输入都必须是 DTensor。这是因为语义不明确。我们不知道 tensor 在不同 rank 之间是否相同,或者它是否不同,因此我们要求用户弄清楚如何从 tensor 构建具有准确放置的 DTensor。
如果每个 rank 都有相同的 tensor,那么请构造一个复制的 DTensor。
tensor = torch.arange(10)
tensor = DTensor.from_local(tensor, placements=(Replicate(),))
return tensor + dtensor
如果您想创建带有分片的 DTensor,请参考以下方法。语义上这意味着您的张量数据在分片之间分割,并且操作作用于“完整的堆叠数据”。
tensor = torch.full([], RANK)
tensor = DTensor.from_local(tensor, placements=(Shard(0),))
return tensor + dtensor
您可能还想对张量进行超出这些情况的其他操作(这些不是唯一的两种选择!)。
情况 2:错误来自 PyTorch 框架代码#
有时问题在于 PyTorch 框架代码尝试执行混合张量-DTensor 操作。这些是 PyTorch 中的错误,请提交 issue 以便我们修复它们。
在用户端,您能做的唯一事情就是避免使用导致问题的操作并提交错误报告。
对于 PyTorch 开发人员:一种修复方法是重写 PyTorch 框架代码以避免混合张量-DTensor 代码(如上一节所示)。
对于 PyTorch 开发人员:第二种方法是在 PyTorch 框架代码的正确位置启用 DTensor 隐式复制。启用后,任何混合张量-DTensor 操作都将假定非 DTensor 可以被复制。请谨慎使用此选项,因为它可能导致静默不正确的行为。