TensorDict 在分布式设置中¶
TensorDict 可以在分布式设置中使用,用于在节点之间传递张量。如果两个节点可以访问共享的物理存储,则可以使用内存映射张量(memory-mapped tensor)在运行中的进程之间高效地传递数据。在此,我们提供一些关于如何在分布式 RPC 设置中实现这一点的信息。有关分布式 RPC 的更多详细信息,请参阅 官方 pytorch 文档。
创建内存映射的 TensorDict¶
内存映射张量(和数组)的一个巨大优点是它们可以存储大量数据,并允许随时访问数据的切片,而无需将整个文件读入内存。TensorDict 在内存映射数组和 torch.Tensor
类之间提供了一个名为 MemmapTensor
的接口。MemmapTensor
实例可以存储在 TensorDict
对象中,从而允许 tensordict 表示存储在磁盘上的大数据集,并可在节点之间以批处理的方式轻松访问。
内存映射的 tensordict 可以通过以下方式创建:(1) 使用内存映射张量填充 TensorDict,或者 (2) 调用 tensordict.memmap_()
将其放入物理存储。可以通过查询 tensordict.is_memmap() 来轻松检查 tensordict 是否已放入物理存储。
创建内存映射张量本身有几种方法。首先,可以简单地创建一个空张量
>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)
prefix
属性指示临时文件存储的位置。至关重要的是,该张量必须存储在每个节点都可以访问的目录中!
另一种选择是将磁盘上的现有张量表示出来
>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")
当张量很大或不适合放入内存时,将优先选择前一种方法:它适用于非常大的张量,并作为节点之间的公共存储。例如,可以创建一个数据集,以便单节点或不同节点都能轻松访问,比加载每个文件到内存中要快得多。
>>> dataset = TensorDict({
... "images": MemmapTensor(50000, 480, 480, 3),
... "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
... "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
fields={
images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
batch_size=torch.Size([4]),
device=cpu,
is_shared=False)
请注意,我们已指定了 MemmapTensor
的设备。这种语法糖允许在需要时将查询到的张量直接加载到设备上。
需要考虑的另一个问题是,目前 MemmapTensor
与自动微分(autograd)操作不兼容。
跨节点操作内存映射张量¶
我们提供了一个简单的分布式脚本示例,其中一个进程创建一个内存映射张量,并将其引用发送给另一个负责更新它的工作进程。您可以在 benchmark 目录 中找到此示例。
简而言之,我们的目标是展示在节点可以访问共享物理存储时,如何处理大张量的读写操作。步骤包括:
在磁盘上创建空张量;
设置要执行的本地和远程操作;
使用 RPC 在工作进程之间传递命令,以读取和写入共享数据。
该示例首先编写一个函数,该函数使用填充了 1 的张量来更新特定索引处的 TensorDict 实例。
>>> def fill_tensordict(tensordict, idx):
... tensordict[idx] = TensorDict(
... {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
... )
... return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
CloudpickleWrapper
确保该函数是可序列化的。接下来,我们创建一个相当大的 tensordict,以此说明如果必须通过常规的 tensorpipe 传递,它将很难从一个工作进程传递到另一个工作进程。
>>> tensordict = TensorDict(
... {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )
最后,仍然在主节点上,我们在 *远程节点* 上调用该函数,然后检查数据是否已写入所需位置。
>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
... worker_info,
... fill_tensordict_cp,
... args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())
尽管调用 rpc.rpc_sync
涉及传递整个 tensordict,更新该对象的特定索引并将其返回给原始工作进程,但该代码片段的执行速度非常快(如果内存位置的引用已预先传递,速度会更快,请参阅 torchrl 的分布式回放缓冲区文档 了解更多信息)。
该脚本包含额外的 RPC 配置步骤,超出了本文档的范围。