CompressedListStorage¶
- class torchrl.data.replay_buffers.CompressedListStorage(max_size: int, *, compression_fn: Callable | None = None, decompression_fn: Callable | None = None, compression_level: int = 3, device: torch.device = 'cpu', compilable: bool = False)[源代码]¶
一个压缩和解压缩数据的存储。
此存储在存储时压缩数据,在检索时解压缩。它特别适用于存储可以被显著压缩以节省内存的原始感官观察(如图像)。
- 参数:
max_size (int) – 存储大小,即缓冲区中存储的最大元素数量。
compression_fn (callable, optional) – 用于压缩数据的函数。应接受一个张量并返回一个压缩后的字节张量。默认为 zstd 压缩。
decompression_fn (callable, optional) – 用于解压缩数据的函数。应接受一个压缩后的字节张量并返回原始张量。默认为 zstd 解压缩。
compression_level (int, optional) – 使用默认压缩函数时,压缩级别(zstd 为 1-22)。默认为 3。
device (torch.device, optional) – 存储和发送采样张量的设备。默认为
torch.device("cpu")
。compilable (bool, optional) – 存储是否可编译。如果为
True
,则写入器不能在多个进程之间共享。默认为False
。
示例
>>> import torch >>> from torchrl.data import CompressedListStorage, ReplayBuffer >>> from tensordict import TensorDict >>> >>> # Create a compressed storage for image data >>> storage = CompressedListStorage(max_size=1000, compression_level=3) >>> rb = ReplayBuffer(storage=storage, batch_size=5) >>> >>> # Add some image data >>> images = torch.randn(10, 3, 84, 84) # Atari-like frames >>> data = TensorDict({"obs": images}, batch_size=[10]) >>> rb.extend(data) >>> >>> # Sample and verify data is decompressed correctly >>> sample = rb.sample(3) >>> print(sample["obs"].shape) # torch.Size([3, 3, 84, 84])
- attach(buffer: Any) None ¶
此函数将采样器附加到此存储。
从该存储读取的缓冲区必须通过调用此方法作为已附加实体包含进来。这确保了当存储中的数据发生变化时,组件能够感知到这些变化,即使该存储与其他缓冲区(例如,Priority Samplers)共享。
- 参数:
buffer – 读取此存储的对象。
- dump(*args, **kwargs)¶
dumps()
的别名。
- load(*args, **kwargs)¶
loads()
的别名。
- save(*args, **kwargs)¶
dumps()
的别名。
- to_bytestream(data_to_bytestream: torch.Tensor | np.array | Any) bytes [源代码]¶
将数据转换为字节流。