快捷方式

CompressedListStorage

class torchrl.data.replay_buffers.CompressedListStorage(max_size: int, *, compression_fn: Callable | None = None, decompression_fn: Callable | None = None, compression_level: int = 3, device: torch.device = 'cpu', compilable: bool = False)[源代码]

一个压缩和解压缩数据的存储。

此存储在存储时压缩数据,在检索时解压缩。它特别适用于存储可以被显著压缩以节省内存的原始感官观察(如图像)。

参数:
  • max_size (int) – 存储大小,即缓冲区中存储的最大元素数量。

  • compression_fn (callable, optional) – 用于压缩数据的函数。应接受一个张量并返回一个压缩后的字节张量。默认为 zstd 压缩。

  • decompression_fn (callable, optional) – 用于解压缩数据的函数。应接受一个压缩后的字节张量并返回原始张量。默认为 zstd 解压缩。

  • compression_level (int, optional) – 使用默认压缩函数时,压缩级别(zstd 为 1-22)。默认为 3。

  • device (torch.device, optional) – 存储和发送采样张量的设备。默认为 torch.device("cpu")

  • compilable (bool, optional) – 存储是否可编译。如果为 True,则写入器不能在多个进程之间共享。默认为 False

示例

>>> import torch
>>> from torchrl.data import CompressedListStorage, ReplayBuffer
>>> from tensordict import TensorDict
>>>
>>> # Create a compressed storage for image data
>>> storage = CompressedListStorage(max_size=1000, compression_level=3)
>>> rb = ReplayBuffer(storage=storage, batch_size=5)
>>>
>>> # Add some image data
>>> images = torch.randn(10, 3, 84, 84)  # Atari-like frames
>>> data = TensorDict({"obs": images}, batch_size=[10])
>>> rb.extend(data)
>>>
>>> # Sample and verify data is decompressed correctly
>>> sample = rb.sample(3)
>>> print(sample["obs"].shape)  # torch.Size([3, 3, 84, 84])
attach(buffer: Any) None

此函数将采样器附加到此存储。

从该存储读取的缓冲区必须通过调用此方法作为已附加实体包含进来。这确保了当存储中的数据发生变化时,组件能够感知到这些变化,即使该存储与其他缓冲区(例如,Priority Samplers)共享。

参数:

buffer – 读取此存储的对象。

bytes()[源代码]

返回存储中的字节数。

dump(*args, **kwargs)

dumps() 的别名。

load(*args, **kwargs)

loads() 的别名。

load_state_dict(state_dict: dict[str, Any]) None[源代码]

加载存储状态。

save(*args, **kwargs)

dumps() 的别名。

state_dict() dict[str, Any][源代码]

保存存储状态。

to_bytestream(data_to_bytestream: torch.Tensor | np.array | Any) bytes[源代码]

将数据转换为字节流。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源