TensorDictMaxValueWriter
- class torchrl.data.replay_buffers.TensorDictMaxValueWriter(rank_key=None, reduction: str = 'sum', **kwargs)
A Writer class for composable replay buffers that keeps the top elements according to some ranking key.
- Parameters:
  - rank_key (str or tuple of str) – the key to rank the data by. Defaults to ("next", "reward").
  - reduction (str) – the reduction method to use if the rank key has more than one element. Can be "max", "min", "mean", "median" or "sum". A sketch of this behaviour is shown after the examples below.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9
>>> td = TensorDict({
...     "key": torch.tensor(range(10, 20)),
...     "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
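The example above ranks on a scalar key. When each stored element's rank key carries an extra dimension (for instance, several per-step values per element), the reduction argument collapses it to a single score before ranking. The following is a minimal sketch of that behaviour, assuming a [batch, 2]-shaped "key" reduced with "mean"; the tensor values and the expected output are illustrative, not taken from the library's documentation.

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key", reduction="mean"),
... )
>>> # Each of the 10 elements carries a 2-dimensional "key"; the writer ranks
>>> # elements by the mean of that vector before deciding what to keep.
>>> td = TensorDict({
...     "key": torch.arange(20, dtype=torch.float).reshape(10, 2),
...     "obs": torch.arange(10),
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())  # element 9 has the largest mean "key"
9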
Note
This class is not compatible with storages of more than one dimension. This does not mean that storing trajectories is forbidden, but rather that the trajectories must be stored on a per-trajectory basis. Here are some examples of valid and invalid usages of the class. First, a flat buffer where we store individual transitions:
>>> from torchrl.data import TensorStorage
>>> # Simplest use case: data comes in 1d and is stored as such
>>> data = TensorDict({
...     "obs": torch.zeros(10, 3),
...     "reward": torch.zeros(10, 1),
... }, batch_size=[10])
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(data)
>>> # Samples 5 *transitions* at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)
Second, a buffer where we store trajectories. The max signal is aggregated within each batch (e.g. the reward of each rollout is summed):
>>> # One can also store batches of data, each batch being a sub-trajectory
>>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> # Get a batch of [2, 10] -- format is [Batch, Time]
>>> rollout = env.rollout(max_steps=10)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
>>> rb.extend(rollout)
>>> # Sample 5 trajectories at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5, 10)
If data comes in batches but a flat buffer is needed, we can simply flatten the data before extending the buffer:
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(rollout.reshape(-1))
>>> # Sample 5 transitions at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)
It is not possible to create a buffer that is extended along the time dimension, which is otherwise the recommended way of using buffers with batches of trajectories. Since trajectories would overlap, aggregating the reward values and comparing them is difficult, if not impossible. This constructor is not valid (note the ndim argument):
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
- add(data: Any) → int | torch.Tensor
Inserts a single element of data at an appropriate index, and returns that index.
The rank_key data passed to this module should be structured as []. If it has more dimensions, it will be reduced to a single value using the reduction method.
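As a rough illustration of the shape requirement above, the following sketch (with made-up values) adds unbatched elements one at a time; each element's "key" is a small vector that the writer reduces with "max" before deciding whether the element enters the buffer.

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(2),
...     writer=TensorDictMaxValueWriter(rank_key="key", reduction="max"),
... )
>>> # Each call inserts one unbatched element (batch_size=[]); the returned index is kept.
>>> idx0 = rb.add(TensorDict({"key": torch.tensor([1.0, 5.0]), "obs": torch.tensor(0)}, batch_size=[]))
>>> idx1 = rb.add(TensorDict({"key": torch.tensor([2.0, 3.0]), "obs": torch.tensor(1)}, batch_size=[]))
>>> idx2 = rb.add(TensorDict({"key": torch.tensor([0.0, 4.0]), "obs": torch.tensor(2)}, batch_size=[]))
>>> # With max_size=2, only the two highest-ranked elements remain: the element with
>>> # obs=1 (reduced key 3.0) is overwritten when obs=2 (reduced key 4.0) arrives.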