TensorStorage¶
- class torchrl.data.replay_buffers.TensorStorage(storage, max_size=None, *, device: device = 'cpu', ndim: int = 1, compilable: bool = False)[源代码]¶
用于存储张量和张量字典的存储。
- 参数:
storage (tensor or TensorDict) – 要使用的数据缓冲区。
max_size (int) – 存储大小,即缓冲区中存储的最大元素数量。
- 关键字参数:
device (torch.device, optional) – 采样张量将被存储和发送的设备。默认为
torch.device("cpu")
。如果传入“auto”,则设备将从传入的第一个批次数据自动收集。默认情况下不启用此功能,以避免数据意外放置在 GPU 上,从而导致 OOM 问题。ndim (int, optional) – 在衡量存储大小时要考虑的维度数。例如,形状为
[3, 4]
的存储,如果ndim=1
,容量为3
;如果ndim=2
,容量为12
。默认为1
。compilable (bool, optional) – 存储是否可编译。如果为
True
,则写入器不能在多个进程之间共享。默认为False
。
示例
>>> data = TensorDict({ ... "some data": torch.randn(10, 11), ... ("some", "nested", "data"): torch.randn(10, 11, 12), ... }, batch_size=[10, 11]) >>> storage = TensorStorage(data) >>> len(storage) # only the first dimension is considered as indexable 10 >>> storage.get(0) TensorDict( fields={ some data: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), some: TensorDict( fields={ nested: TensorDict( fields={ data: Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([11]), device=None, is_shared=False)}, batch_size=torch.Size([11]), device=None, is_shared=False)}, batch_size=torch.Size([11]), device=None, is_shared=False) >>> storage.set(0, storage.get(0).zero_()) # zeros the data along index ``0``
此类也支持 tensorclass 数据。
示例
>>> from tensordict import tensorclass >>> @tensorclass ... class MyClass: ... foo: torch.Tensor ... bar: torch.Tensor >>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11]) >>> storage = TensorStorage(data) >>> storage.get(0) MyClass( bar=Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False), foo=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), batch_size=torch.Size([11]), device=None, is_shared=False)
- attach(buffer: Any) None ¶
此函数将采样器附加到此存储。
从该存储读取的缓冲区必须通过调用此方法作为附加实体包含。这可以确保当存储中的数据发生变化时,组件能够感知到这些变化,即使存储与其他缓冲区(例如优先级采样器)共享。
- 参数:
buffer – 读取此存储的对象。
- dump(*args, **kwargs)¶
dumps()
的别名。
- load(*args, **kwargs)¶
loads()
的别名。
- save(*args, **kwargs)¶
dumps()
的别名。