TensorDictMaxValueWriter

class torchrl.data.replay_buffers.TensorDictMaxValueWriter(rank_key=None, reduction: str = 'sum', **kwargs)[source]

A Writer class for composable replay buffers that keeps the top elements based on some ranking key.

Parameters:
  • rank_key (str or tuple of str) – the key used to rank the data. Defaults to ("next", "reward").

  • reduction (str) – the reduction method to use if the rank key has more than one element. Can be "max", "min", "mean", "median" or "sum" (see the sketch after this list).
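
When an element's rank key holds several values (for instance, a per-step reward along a stored trajectory), reduction collapses them into a single score per element. A minimal sketch of this, assuming per-step rewards live under ("next", "reward"); the tensor names and shapes are illustrative:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key=("next", "reward"), reduction="sum"),
... )
>>> # Two 5-step trajectories; each is scored by the sum of its rewards
>>> td = TensorDict({
...     "obs": torch.randn(2, 5, 3),
...     ("next", "reward"): torch.tensor([[0.1] * 5, [1.0] * 5]).unsqueeze(-1),
... }, batch_size=[2, 5])
>>> rb.extend(td)
>>> # Only the trajectory with the highest summed reward (5 * 1.0) is kept
>>> print(rb.sample().get(("next", "reward")).sum().item())
5.0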

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9
>>> td = TensorDict({
...     "key": torch.tensor(range(10, 20)),
...     "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
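>>> # All new "key" values (0-9) rank below the stored maximum (19), so nothing is written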
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19

Note

This class is not compatible with storages with more than one dimension. This does not mean that storing trajectories is prohibited, but the trajectories must be stored on a per-trajectory basis. Here are some examples of valid and invalid usages of the class. First, a flat buffer where we store individual transitions:

>>> from torchrl.data import LazyTensorStorage
>>> # Simplest use case: data comes in 1d and is stored as such
>>> data = TensorDict({
...     "obs": torch.zeros(10, 3),
...     "reward": torch.zeros(10, 1),
... }, batch_size=[10])
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(data)
>>> # Samples 5 *transitions* at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)

Second, a buffer where we store trajectories. The max signal is aggregated within each batch (e.g., the reward of each rollout is summed):

>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> # One can also store batches of data, each batch being a sub-trajectory
>>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> # Get a batch of [2, 10] -- format is [Batch, Time]
>>> rollout = env.rollout(max_steps=10)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
>>> rb.extend(rollout)
>>> # Sample 5 trajectories at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5, 10)

If the data comes in batches but a flat buffer is needed, we can simply flatten the data before extending the buffer:

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(rollout.reshape(-1))
>>> # Sample 5 *transitions* at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)

It is not possible to create a buffer that is extended along the time dimension, which is usually the recommended way of using buffers with batches of trajectories. Since the trajectories are overlapping, it is hard, if not impossible, to aggregate the reward values and compare them. This constructor is not allowed (notice the ndim argument):

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
add(data: Any) → int | torch.Tensor[source]

Inserts a single element of data at an appropriate index, and returns that index.

The `rank_key` in the data passed to this module should be structured as []. If it has more dimensions, it will be reduced to a single value using the `reduction` method.
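
As a minimal sketch of single-element writes through the parent buffer (the "obs" entry and the buffer configuration are illustrative assumptions):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=10),
...     writer=TensorDictMaxValueWriter(rank_key="reward"),
... )
>>> # Write 20 transitions one by one; only the 10 highest-reward ones are retained
>>> for r in range(20):
...     index = rb.add(TensorDict({
...         "obs": torch.randn(3),
...         "reward": torch.tensor([float(r)]),  # shape [1] is reduced to a scalar rank
...     }, batch_size=[]))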

extend(data: TensorDictBase) → None[source]

Inserts a series of data points at appropriate indices.

The `rank_key` in the data passed to this module should be structured as [B]. If it has more dimensions, it will be reduced to a single value using the `reduction` method.

get_insert_index(data: Any) → int[source]

Returns the index where the data should be inserted, or `None` if it should not be inserted.
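
A minimal sketch of the return contract. This method is what add() and extend() rely on internally, and calling it is assumed here to also update the writer's internal ranking state, so the direct queries below are for illustration rather than a recommended pattern:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> writer = TensorDictMaxValueWriter(rank_key="reward")
>>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=2), writer=writer)
>>> index = rb.add(TensorDict({"reward": torch.tensor([1.0])}, batch_size=[]))
>>> index = rb.add(TensorDict({"reward": torch.tensor([3.0])}, batch_size=[]))
>>> # The buffer is full: a candidate outranking the stored minimum gets an index...
>>> writer.get_insert_index(TensorDict({"reward": torch.tensor([2.0])}, batch_size=[])) is None
False
>>> # ...while one ranking below every stored element does not
>>> writer.get_insert_index(TensorDict({"reward": torch.tensor([0.5])}, batch_size=[])) is None
True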
