评价此页

PyTorch 对称内存#

创建日期:2025 年 10 月 24 日 | 最后更新日期:2026 年 1 月 8 日

注意

torch.distributed._symmetric_memory 目前处于 Alpha 阶段,正在开发中。API 可能会发生更改。

为什么需要对称内存?#

随着并行化技术的飞速发展,现有的框架和库往往难以跟上步伐,开发者越来越依赖自定义实现来直接调度通信和计算。近年来,我们见证了从主要依赖一维数据并行技术转向多维并行技术的转变。后者对不同类型的通信具有不同的延迟要求,因此需要细粒度地重叠计算和通信。

为了最大限度地减少计算干扰,它们还需要使用拷贝引擎和网络接口卡 (NIC) 来驱动通信。远程直接内存访问 (RDMA) 等网络传输协议通过实现处理器和内存之间直接、高速、低延迟的通信来提高性能。这种多样性的增加表明需要比当前高级集合 API 更细粒度的通信原语,这些原语能够使开发人员实现针对其用例量身定制的特定算法,例如低延迟集合、细粒度计算-通信重叠或自定义融合。

此外,当今先进的 AI 系统通过高带宽链路(如 NVLinks、InfiniBand 或 RoCE)连接 GPU,使得 GPU 全局内存可以直接被对等方访问。这些连接为程序员提供了一个绝佳的机会,可以将系统编程为一个单一的、巨大的 GPU,拥有海量的可访问内存,而不是对单个“GPU 孤岛”进行编程。

在本教程中,我们将展示如何使用 PyTorch 对称内存将现代 GPU 系统编程为“单一 GPU”,并实现细粒度的远程访问。

PyTorch 对称内存带来了哪些新能力?#

PyTorch 对称内存解锁了三项新功能

  • 定制通信模式:增加了内核编写的灵活性,允许开发人员编写自定义内核来实现其自定义计算和通信,直接根据应用程序的需求进行量身定制。还可以轻松地添加对新数据类型的支持,以及这些数据类型可能需要的特殊计算,即使它们尚未在标准库中存在。

  • 内核内计算-通信融合:设备发起的通信功能允许开发人员编写同时包含计算和通信指令的内核,从而能够在尽可能小的粒度上融合计算和数据移动。

  • 低延迟远程访问:RDMA 等网络传输协议通过实现处理器和内存之间直接、高速、低延迟的通信,提高了对称内存在网络环境中的性能。RDMA 消除了与传统网络堆栈和 CPU 参与相关的开销。它还将数据传输从计算卸载到 NIC,从而释放计算资源用于计算任务。

接下来,我们将展示 PyTorch 对称内存 (SymmMem) 如何通过上述功能实现新应用。

一个“Hello World”示例#

PyTorch SymmMem 编程模型包含两个关键元素:

  • 创建对称张量

  • 创建 SymmMem 内核

要创建对称张量,可以使用 torch.distributed._symmetric_memory 包。

import torch.distributed._symmetric_memory as symm_mem

t = symm_mem.empty(128, device=torch.device("cuda", rank))
hdl = symm_mem.rendezvous(t, group)

函数 symm_mem.empty 创建一个由对称内存分配支持的张量。函数 rendezvous 与组中的对等方建立一次会合,并返回一个指向对称内存分配的句柄。该句柄提供了访问与对称内存分配相关的信息的方法,例如对等方 rank 上的对称缓冲区指针、多播指针(如果支持)和信号填充区。

函数 emptyrendezvous 必须在组中的所有 rank 上按相同的顺序调用。

然后,可以在这些张量上调用集合操作。例如,执行一次性 all-reduce:

# Most SymmMem ops are under the torch.ops.symm_mem namespace
torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group)

请注意,torch.ops.symm_mem 是一个“op 命名空间”,而不是一个 Python 模块。因此,您无法通过 import torch.ops.symm_mem 来导入它,也不能通过 from torch.ops.symm_mem import one_shot_all_reduce 来导入一个 op。您可以直接调用 op,如上面的示例所示。

编写自己的内核#

要编写自己的与对称内存进行通信的内核,您需要访问映射的对等缓冲区地址以及进行同步所需的信号填充区。在内核中,您还需要执行正确的同步,以确保对等方已准备好进行通信,并向它们发出信号表明此 GPU 已准备就绪。

PyTorch 对称内存提供了与 CUDA Graph 兼容的同步原语,这些原语作用于每个对称内存分配附带的信号填充区。使用对称内存的内核可以用 CUDA 和 Triton 编写。以下是一个分配对称张量并交换句柄的示例:

import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group()
rank = dist.get_rank()

# Allocate a tensor
t = symm_mem.empty(4096, device=f"cuda:{rank}")
# Establish symmetric memory and obtain the handle
hdl = symm_mem.rendezvous(t, dist.group.WORLD)

通过以下方式提供对缓冲区指针、多内存指针和信号填充区的访问:

hdl.buffer_ptrs
hdl.multicast_ptr
hdl.signal_pad_ptrs

buffer_ptrs 指向的数据可以像常规本地数据一样访问,并且任何必要的计算也可以按常规方式执行。与本地数据一样,您可以使用向量化访问来提高效率。

对称内存对于编写 Triton 内核特别方便。虽然以前 Triton 打破了编写高效 CUDA 代码的障碍,但现在可以将通信轻松添加到 Triton 内核中。下面的内核演示了一个用 Triton 编写的低延迟 all-reduce 内核。

@triton.jit
def one_shot_all_reduce_kernel(
    buf_tuple,
    signal_pad_ptrs,
    output_ptr,
    numel: tl.constexpr,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
    )

    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE

    while block_start < numel:
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < numel
        acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)

        for i in tl.static_range(world_size):
            buffer_rank = buf_tuple[i]
            x = tl.load(buffer_rank + offsets, mask=mask)
            acc += x

        tl.store(output_ptr + offsets, acc, mask=mask)
        block_start += tl.num_programs(axis=0) * BLOCK_SIZE

    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
    )

内核顶部的同步保证了所有进程都看到一致的数据。内核的其余部分是可识别的 Triton 代码,Triton 会在后台对其进行优化,确保内存访问以向量化和展开的方式高效执行。与所有 Triton 内核一样,它很容易修改以添加额外的计算或更改通信算法。请访问 https://github.com/meta-pytorch/kraken/blob/main/kraken 查看使用对称内存实现 Triton 中常见模式的其他实用程序和示例。

扩展#

大型语言模型将专家分布在 8 个以上的 GPU 上,因此需要多节点访问能力。支持 RDMA 的 NIC 在此发挥作用。此外,NVSHMEM 或 rocSHMEM 等软件库通过比指针访问略高级的原语(如 put 和 get)来抽象化节点内访问和节点间访问之间的编程差异。

PyTorch 提供了 NVSHMEM 插件来增强 Triton 内核的跨节点能力。如以下代码片段所示,可以在内核中发起跨节点 put 命令。

import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem

@requires_nvshmem
@triton.jit
def my_put_kernel(
    dest,
    src,
    nelems,
    pe,
):
    nvshmem.put(dest, src, nelems, pe)

使用 requires_nvshmem 装饰器来指示内核需要 NVSHMEM 设备库作为外部依赖。当 Triton 编译内核时,装饰器将在您的系统路径中搜索 NVSHMEM 设备库。如果可用,Triton 将包含使用 NVSHMEM 函数所需的设备程序集。

使用内存池#

内存池允许 PyTorch SymmMem 缓存已经进行过会合的内存分配,从而在创建新张量时节省时间。为了方便起见,PyTorch SymmMem 添加了一个 get_mem_pool API 来返回一个对称内存池。用户可以使用返回的 MemPool 和 torch.cuda.use_mem_pool 上下文管理器。在下面的示例中,张量 x 将从对称内存创建。

    import torch.distributed._symmetric_memory as symm_mem

    mempool = symm_mem.get_mem_pool(device)

    with torch.cuda.use_mem_pool(mempool):
        x = torch.arange(128, device=device)

    torch.ops.symm_mem.one_shot_all_reduce(x, "sum", group_name)

同样,您也可以在 MemPool 上下文下进行计算操作,结果张量也将从对称内存创建。

    dim = 1024
    w = torch.ones(dim, dim, device=device)
    x = torch.ones(1, dim, device=device)

    mempool = symm_mem.get_mem_pool(device)
    with torch.cuda.use_mem_pool(mempool):
        # y will be in symmetric memory
        y = torch.mm(x, w)

截至 torch 2.11,CUDANVSHMEM 后端支持 MemPool。MemPool 对 NCCL 后端的支持正在进行中。

API参考#

torch.distributed._symmetric_memory.empty(*size: _int, dtype: _dtype | None = None, device: _device | None = None) Tensor[源代码]#
torch.distributed._symmetric_memory.empty(size: Sequence[_int], *, dtype: _dtype | None = None, device: _device | None = None) Tensor

类似于 torch.empty()。返回的张量可用于 torch._distributed._symmetric_memory.rendezvous() 以在参与进程之间建立对称张量。

参数:

size (int...) – 定义输出张量形状的整数序列。可以是可变数量的参数,也可以是列表或元组之类的集合。

关键字参数:
  • dtype (torch.dtype, optional) – 返回张量所需的数据类型。默认为:如果 None,则使用全局默认值(请参阅 torch.set_default_dtype())。

  • device (torch.device, optional) – 返回张量所需的设备。默认为:如果 None,则使用当前设备的默认张量类型(请参阅 torch.set_default_device())。对于 CPU 张量类型,device 将是 CPU,对于 CUDA 张量类型,则是当前 CUDA 设备。

torch.distributed._symmetric_memory.rendezvous(tensor, group) _SymmetricMemory[源代码]#

在参与进程之间建立对称张量。这是一个集体操作。

参数:
  • tensor (torch.Tensor) – 用于建立对称张量的本地张量。它必须通过 torch._distributed._symmetric_memory.empty() 分配。所有参与进程的形状、dtype 和设备类型必须相同。

  • group (Union[str, torch.distributed.ProcessGroup]) – 标识参与进程的组。它可以是组名或进程组对象。

返回类型:

_SymmetricMemory

torch.distributed._symmetric_memory.is_nvshmem_available() bool[源代码]#

检查当前构建和当前系统是否可用 NVSHMEM。

返回类型:

布尔值

torch.distributed._symmetric_memory.set_backend(name)[源代码]#

设置对称内存分配的后端。这是一个全局设置,会影响所有后续调用 torch._distributed._symmetric_memory.empty() 的操作。请注意,一旦分配了对称内存张量,就无法更改后端。

参数:

backend (str) – 对称内存分配的后端。当前仅支持 “NVSHMEM”“CUDA”“NCCL”

torch.distributed._symmetric_memory.get_backend(device)[源代码]#

获取给定设备的对称内存分配的后端。如果未找到,则返回 None。

参数:

device (torch.device 或 str) – 要获取后端的设备。

返回类型:

str | None

torch.distributed._symmetric_memory.get_mem_pool(device)[源代码]#

获取给定设备的对称内存池。如果未找到,则创建一个新池。

此池的张量分配必须在 ranks 之间是对称的。分配的张量可与对称操作一起使用,例如,在 torch.ops.symm_mem 下定义的那些操作。

参数:

device (torch.device 或 str) – 要获取对称内存池的设备。

返回:

给定设备的对称内存池。

返回类型:

torch.cuda.MemPool

示例

>>> pool = torch.distributed._symmetric_memory.get_mem_pool("cuda:0")
>>> with torch.cuda.use_mem_pool(pool):
>>>     tensor = torch.randn(1000, device="cuda:0")
>>> tensor = torch.ops.symm_mem.one_shot_all_reduce(tensor, "sum", group_name)

Op 参考#

注意

以下 Op 托管在 torch.ops.symm_mem 命名空间下。您可以直接通过 torch.ops.symm_mem.<op_name> 调用它们。

torch.ops.symm_mem.multimem_all_reduce_(input: Tensor, reduce_op: str, group_name: str) Tensor#

对输入张量执行 multimem all-reduce 操作。此操作需要硬件支持 multimem 操作。在 NVIDIA GPU 上,需要 NVLink SHARP。

参数:
  • input (Tensor) – 用于执行 all-reduce 的输入张量。必须是对称的。

  • reduce_op (str) – 要执行的归约操作。目前仅支持“sum”。

  • group_name (str) – 要执行 all-reduce 的组名。

torch.ops.symm_mem.multimem_all_gather_out(input: Tensor, group_name: str, out: Tensor) Tensor#

对输入张量执行 multimem all-gather 操作。此操作需要硬件支持 multimem 操作。在 NVIDIA GPU 上,需要 NVLink SHARP。

参数:
  • input (Tensor) – 用于执行 all-gather 的输入张量。

  • group_name (str) – 要执行 all-gather 的组名。

  • out (Tensor) – 用于存储 all-gather 操作结果的输出张量。必须是对称的。

torch.ops.symm_mem.one_shot_all_reduce(input: Tensor, reduce_op: str, group_name: str) Tensor#

对输入张量执行一次性 all-reduce 操作。

参数:
  • input (Tensor) – 用于执行 all-reduce 的输入张量。必须是对称的。

  • reduce_op (str) – 要执行的归约操作。目前仅支持“sum”。

  • group_name (str) – 要执行 all-reduce 的组名。

torch.ops.symm_mem.one_shot_all_reduce_out(input: Tensor, reduce_op: str, group_name: str, out: Tensor) Tensor#

根据输入张量执行一次性 all-reduce 操作,并将结果写入输出张量。

参数:
  • input (Tensor) – 用于执行 all-reduce 的输入张量。必须是对称的。

  • reduce_op (str) – 要执行的归约操作。目前仅支持“sum”。

  • group_name (str) – 要执行 all-reduce 的组名。

  • out (Tensor) – 用于存储 all-reduce 操作结果的输出张量。可以是一个常规张量。

torch.ops.symm_mem.two_shot_all_reduce_(input: Tensor, reduce_op: str, group_name: str) Tensor#

对输入张量执行两步 all-reduce 操作。

参数:
  • input (Tensor) – 用于执行 all-reduce 的输入张量。必须是对称的。

  • reduce_op (str) – 要执行的归约操作。目前仅支持“sum”。

  • group_name (str) – 要执行 all-reduce 的组名。

torch.ops.symm_mem.all_to_all_vdev(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str) None#

使用 NVSHMEM 执行 all-to-all-v 操作,并提供设备上的拆分信息。

参数:
  • input (Tensor) – 用于执行 all-to-all 的输入张量。必须是对称的。

  • out (Tensor) – 用于存储 all-to-all 操作结果的输出张量。必须是对称的。

  • in_splits (Tensor) – 包含要发送到每个对等方的数据拆分的张量。必须是对称的。必须是大小为 (group_size,) 的张量。拆分以第一维度中的元素为单位。

  • out_splits_offsets (Tensor) – 包含从每个对等方接收的数据的分割和偏移量的张量。必须是对称的。大小必须为 (2, group_size)。行(按顺序)为:输出分割和输出偏移量。

  • group_name (str) – 进行 all-to-all 操作的组名。

torch.ops.symm_mem.all_to_all_vdev_2d(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str[, major_align: int = None]) None#

使用 NVSHMEM 执行 2D all-to-all-v 操作,并在设备上提供分割信息。在混合专家模型中,此操作可用于分派 token。

参数:
  • input (Tensor) – 用于执行 all-to-all 的输入张量。必须是对称的。

  • out (Tensor) – 用于存储 all-to-all 操作结果的输出张量。必须是对称的。

  • in_splits (Tensor) – 包含发送到每个专家的数据的分割的张量。必须是对称的。大小必须为 (group_size * ne,),其中 ne 是每个 rank 的专家数量。分割的单位是第一个维度的元素。

  • out_splits_offsets (Tensor) – 包含从每个对等方接收的数据的分割和偏移量的张量。必须是对称的。大小必须为 (2, group_size * ne)。行(按顺序)为:输出分割和输出偏移量。

  • group_name (str) – 进行 all-to-all 操作的组名。

  • major_align (int) – 每个专家的输出块的主维度可选对齐。如果未提供,则假定对齐为 1。任何对齐调整都将反映在输出偏移量中。

下面说明了 2D AllToAllv shuffle:(world_size = 2, ne = 2, 专家总数 = 4)

Source: |       Rank 0      |       Rank 1      |
        | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

Dest  : |       Rank 0      |       Rank 1      |
        | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |

其中每个 c_i / d_iinput 张量的切片,目标是专家 i,长度由输入分割指示。也就是说,2D AllToAllv shuffle 实现从输入的 rank-major 顺序到输出的 expert-major 顺序的转置。

如果 major_align 不是 1,则 c1、c2、c3 的输出偏移量将向上对齐到此值。例如,如果 c0 的长度为 5,d0 的长度为 7(总共 12),并且 major_align 设置为 16,则 c1 的输出偏移量将为 16。c2 和 c3 类似。此值对次要维度的偏移量(即 d0、d1、d2 和 d3)没有影响。注意:由于 cutlass 不支持空 bin,如果对齐长度为 0,我们将其设置为 major_align。请参阅 pytorch/pytorch#152668

torch.ops.symm_mem.all_to_all_vdev_2d_offset(Tensor input, Tensor out, Tensor in_splits_offsets, Tensor out_splits_offsets, str group_name) None#

执行 2D AllToAllv shuffle 操作,在设备上提供输入分割和偏移量信息。输入偏移量不必是输入分割的精确前缀和,即分割块之间允许填充。但是,这些填充不会传输到对等 rank。

在混合专家模型中,此操作可用于合并在并行 rank 上由专家处理的 token。此操作可视为 all_to_all_vdev_2d 操作(将 token shuffle 到专家)的“反向”操作。

参数:
  • input (Tensor) – 用于执行 all-to-all 的输入张量。必须是对称的。

  • out (Tensor) – 用于存储 all-to-all 操作结果的输出张量。必须是对称的。

  • in_splits_offsets (Tensor) – 包含发送到每个专家的分割和偏移量张量。必须是对称的。大小必须为 (2, group_size * ne),其中 ne 是专家数量。行(按顺序)为:输入分割和输入偏移量。分割的单位是第一个维度的元素。

  • out_splits_offsets (Tensor) – 包含从每个对等方接收的数据的分割和偏移量的张量。必须是对称的。大小必须为 (2, group_size * ne)。行(按顺序)为:输出分割和输出偏移量。

  • group_name (str) – 进行 all-to-all 操作的组名。

torch.ops.symm_mem.tile_reduce(in_tile: Tensor, out_tile: Tensor, root: int, group_name: str[, reduce_op: str = 'sum']) None#

将进程组内所有 rank 的 2D tile 减少到指定的根 rank。

参数:
  • in_tile (Tensor) – 要减少的输入 2D 张量。必须是对称分配的。

  • out_tile (Tensor) – 包含减少结果的输出 2D 张量。必须是对称的,并且具有与 in_tile 相同的形状、dtype 和设备。

  • root (int) – 指定组中接收减少结果的进程的 rank。

  • group_name (str) – 要在此执行减少操作的对称内存进程的名称。

  • reduce_op (str) – 要执行的减少操作。目前只支持 "sum"。默认为 "sum"

此函数将组中所有成员的 in_tile 张量进行减少,并将结果写入根 rank 的 out_tile。所有 rank 都必须参与并提供相同的 group_name 和张量形状。

示例

>>> 
>>> # Reduce the bottom-right quadrant of a tensor
>>> tile_size = full_size // 2
>>> full_inp = symm_mem.empty(full_size, full_size)
>>> full_out = symm_mem.empty(full_size, full_size)
>>> s = slice(tile_size, 2 * tile_size)
>>> in_tile = full_inp[s, s]
>>> out_tile = full_out[s, s]
>>> torch.ops.symm_mem.tile_reduce(in_tile, out_tile, root=0, group_name)
torch.ops.symm_mem.multi_root_tile_reduce(in_tiles: list[Tensor], out_tile: Tensor, roots: list[int], group_name: str, [reduce_op: str = 'sum']) None#

并发执行多个 tile 减少,每个 tile 减少到单独的根。

:param list[Tensor] in_tiles: 输入张量列表。 :param Tensor out_tile: 包含被减少 tile 的输出张量。 :param list[int] roots: 根 rank 列表,每个根 rank 对应 in_tiles 中的一个输入 tile,顺序相同。一个 rank 不能作为根两次。 :param str group_name: 用于集合操作的组名。 :param str reduce_op: 要执行的减少操作。目前只支持“sum”。

示例

>>> 
>>> # Reduce four quadrants of a tensor, each to a different root
>>> tile_size = full_size // 2
>>> full_inp = symm_mem.empty(full_size, full_size)
>>> s0 = slice(0, tile_size)
>>> s1 = slice(tile_size, 2 * tile_size)
>>> in_tiles = [ full_inp[s0, s0], full_inp[s0, s1], full_inp[s1, s0], full_inp[s1, s1] ]
>>> out_tile = symm_mem.empty(tile_size, tile_size)
>>> roots = [0, 1, 2, 3]
>>> torch.ops.symm_mem.multi_root_tile_reduce(in_tiles, out_tile, roots, group_name)