快捷方式

LazyMemmapStorage

class torchrl.data.replay_buffers.LazyMemmapStorage(max_size: int, *, scratch_dir=None, device: device = 'cpu', ndim: int = 1, existsok: bool = False, compilable: bool = False)[源代码]

内存映射的张量和张量字典的存储。

参数:

max_size (int) – 存储大小,即缓冲区中存储的最大元素数量。

关键字参数:
  • scratch_dir (strpath) – 将写入 memmap-tensors 的目录。

  • device (torch.device, 可选) – 存储和发送采样张量的设备。默认为 torch.device("cpu")。如果提供 None,则设备将自动从传递的第一批数据中收集。此功能默认不启用,以避免意外将数据放置在 GPU 上,从而导致 OOM 问题。

  • ndim (int, 可选) – 在测量存储大小时要考虑的维度数量。例如,形状为 [3, 4] 的存储,如果 ndim=1,则容量为 3;如果 ndim=2,则容量为 12。默认为 1

  • existsok (bool, 可选) – 如果任何张量已存在于磁盘上,是否应引发错误。默认为 True。如果为 False,则张量将按原样打开,不会被覆盖。

注意

在检查点 LazyMemmapStorage 时,可以提供与存储已存储位置相同的路径,以避免执行已存储在磁盘上的数据的长时间复制。这仅在使用默认的 TensorStorageCheckpointer 检查点时才有效。示例

>>> from tensordict import TensorDict
>>> from torchrl.data import TensorStorage, LazyMemmapStorage, ReplayBuffer
>>> import tempfile
>>> from pathlib import Path
>>> import time
>>> td = TensorDict(a=0, b=1).expand(1000).clone()
>>> # We pass a path that is <main_ckpt_dir>/storage to LazyMemmapStorage
>>> rb_memmap = ReplayBuffer(storage=LazyMemmapStorage(10_000_000, scratch_dir="dump/storage"))
>>> rb_memmap.extend(td);
>>> # Checkpointing in `dump` is a zero-copy, as the data is already in `dump/storage`
>>> rb_memmap.dumps(Path("./dump"))

示例

>>> data = TensorDict({
...     "some data": torch.randn(10, 11),
...     ("some", "nested", "data"): torch.randn(10, 11, 12),
... }, batch_size=[10, 11])
>>> storage = LazyMemmapStorage(100)
>>> storage.set(range(10), data)
>>> len(storage)  # only the first dimension is considered as indexable
10
>>> storage.get(0)
TensorDict(
    fields={
        some data: MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        some: TensorDict(
            fields={
                nested: TensorDict(
                    fields={
                        data: MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([11]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([11]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([11]),
    device=cpu,
    is_shared=False)

此类也支持 tensorclass 数据。

示例

>>> from tensordict import tensorclass
>>> @tensorclass
... class MyClass:
...     foo: torch.Tensor
...     bar: torch.Tensor
>>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11])
>>> storage = LazyMemmapStorage(10)
>>> storage.set(range(10), data)
>>> storage.get(0)
MyClass(
    bar=MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
    foo=MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
    batch_size=torch.Size([11]),
    device=cpu,
    is_shared=False)
attach(buffer: Any) None

此函数将采样器附加到此存储。

读取此存储的缓冲区必须通过调用此方法作为附加实体包含。这保证了当存储中的数据发生更改时,组件能够感知到更改,即使存储与其他缓冲区共享(例如,优先级采样器)。

参数:

buffer – 读取此存储的对象。

dump(*args, **kwargs)

dumps() 的别名。

load(*args, **kwargs)

loads() 的别名。

save(*args, **kwargs)

dumps() 的别名。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源