• 文档 >
  • 在分布式设置中使用 TensorDict
快捷方式

TensorDict 在分布式设置中

TensorDict 可以在分布式设置中使用,用于在节点之间传递张量。如果两个节点可以访问共享的物理存储,则可以使用内存映射张量(memory-mapped tensor)在运行中的进程之间高效地传递数据。在此,我们提供一些关于如何在分布式 RPC 设置中实现这一点的信息。有关分布式 RPC 的更多详细信息,请参阅 官方 pytorch 文档

创建内存映射的 TensorDict

内存映射张量(和数组)的一个巨大优点是它们可以存储大量数据,并允许随时访问数据的切片,而无需将整个文件读入内存。TensorDict 在内存映射数组和 torch.Tensor 类之间提供了一个名为 MemmapTensor 的接口。MemmapTensor 实例可以存储在 TensorDict 对象中,从而允许 tensordict 表示存储在磁盘上的大数据集,并可在节点之间以批处理的方式轻松访问。

内存映射的 tensordict 可以通过以下方式创建:(1) 使用内存映射张量填充 TensorDict,或者 (2) 调用 tensordict.memmap_() 将其放入物理存储。可以通过查询 tensordict.is_memmap() 来轻松检查 tensordict 是否已放入物理存储。

创建内存映射张量本身有几种方法。首先,可以简单地创建一个空张量

>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)

prefix 属性指示临时文件存储的位置。至关重要的是,该张量必须存储在每个节点都可以访问的目录中!

另一种选择是将磁盘上的现有张量表示出来

>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")

当张量很大或不适合放入内存时,将优先选择前一种方法:它适用于非常大的张量,并作为节点之间的公共存储。例如,可以创建一个数据集,以便单节点或不同节点都能轻松访问,比加载每个文件到内存中要快得多。

在磁盘上创建空数据集
>>> dataset = TensorDict({
...      "images": MemmapTensor(50000, 480, 480, 3),
...      "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
...      "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
    fields={
        images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
        labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
        masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

请注意,我们已指定了 MemmapTensor 的设备。这种语法糖允许在需要时将查询到的张量直接加载到设备上。

需要考虑的另一个问题是,目前 MemmapTensor 与自动微分(autograd)操作不兼容。

跨节点操作内存映射张量

我们提供了一个简单的分布式脚本示例,其中一个进程创建一个内存映射张量,并将其引用发送给另一个负责更新它的工作进程。您可以在 benchmark 目录 中找到此示例。

简而言之,我们的目标是展示在节点可以访问共享物理存储时,如何处理大张量的读写操作。步骤包括:

  • 在磁盘上创建空张量;

  • 设置要执行的本地和远程操作;

  • 使用 RPC 在工作进程之间传递命令,以读取和写入共享数据。

该示例首先编写一个函数,该函数使用填充了 1 的张量来更新特定索引处的 TensorDict 实例。

>>> def fill_tensordict(tensordict, idx):
...     tensordict[idx] = TensorDict(
...         {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
...     )
...     return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)

CloudpickleWrapper 确保该函数是可序列化的。接下来,我们创建一个相当大的 tensordict,以此说明如果必须通过常规的 tensorpipe 传递,它将很难从一个工作进程传递到另一个工作进程。

>>> tensordict = TensorDict(
...     {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )

最后,仍然在主节点上,我们在 *远程节点* 上调用该函数,然后检查数据是否已写入所需位置。

>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
...     worker_info,
...     fill_tensordict_cp,
...     args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())

尽管调用 rpc.rpc_sync 涉及传递整个 tensordict,更新该对象的特定索引并将其返回给原始工作进程,但该代码片段的执行速度非常快(如果内存位置的引用已预先传递,速度会更快,请参阅 torchrl 的分布式回放缓冲区文档 了解更多信息)。

该脚本包含额外的 RPC 配置步骤,超出了本文档的范围。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源