快捷方式

SliceSamplerWithoutReplacement

class torchrl.data.replay_buffers.SliceSamplerWithoutReplacement(*, num_slices: int | None = None, slice_len: int | None = None, drop_last: bool = False, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False, use_gpu: bool | torch.device = False)[来源]

根据开始和停止信号,无放回地沿第一个维度对数据切片进行采样。

在此上下文中,“无放回”意味着在计数器被自动重置之前,同一个元素(不是轨迹)不会被采样两次。然而,在单个样本中,同一轨迹的切片最多只会出现一次(请参阅下面的示例)。

此类应与静态重放缓冲区一起使用,或在两个重放缓冲区扩展之间使用。扩展重放缓冲区将重置采样器,并且当前不允许连续无放回采样。

注意

SliceSamplerWithoutReplacement 在检索轨迹索引时可能会很慢。为了加速其执行,请优先使用 end_key 而不是 traj_key,并考虑以下关键字参数:compilecache_valuesuse_gpu

关键字参数:
  • drop_last (bool, optional) – 如果为 True,则会丢弃最后一个不完整的样本(如果有)。如果为 False,则会保留最后一个样本。默认为 False

  • num_slices (int) – 要采样的切片数量。批次大小必须大于或等于 num_slices 参数。与 slice_len 互斥。

  • slice_len (int) – 要采样的切片的长度。批次大小必须大于或等于 slice_len 参数,并且可以被其整除。与 num_slices 互斥。

  • end_key (NestedKey, optional) – 指示轨迹(或回合)结束的键。默认为 ("next", "done")

  • traj_key (NestedKey, optional) – 指示轨迹的键。默认为 "episode"(在 TorchRL 的数据集中常用)。

  • ends (torch.Tensor, optional) – 一个包含运行结束信号的 1D 布尔张量。当 end_keytraj_key 获取成本较高,或当此信号易于获得时使用。必须与 cache_values=True 一起使用,并且不能与 end_keytraj_key 结合使用。

  • trajectories (torch.Tensor, optional) – 一个包含运行 ID 的 1D 整型张量。当 end_keytraj_key 获取成本较高,或当此信号易于获得时使用。必须与 cache_values=True 一起使用,并且不能与 end_keytraj_key 结合使用。

  • truncated_key (NestedKey, optional) – 如果不为 None,此参数指示在哪里将截断信号写入输出数据。这用于向值估计器指示提供的轨迹在哪里中断。默认为 ("next", "truncated")。此功能仅适用于 TensorDictReplayBuffer 实例(否则,截断键会在 sample() 方法返回的信息字典中)。

  • strict_length (bool, optional) – 如果为 False,则允许批次中出现长度小于 slice_len(或 batch_size // num_slices)的轨迹。如果为 True,则会过滤掉长度不足的轨迹。请注意,这可能导致实际 batch_size 短于要求的!轨迹可以使用 split_trajectories() 进行拆分。默认为 True

  • shuffle (bool, optional) – 如果为 False,则不打乱轨迹的顺序。默认为 True

  • compile (bool or dict of kwargs, optional) – 如果为 True,则 sample() 方法的瓶颈将使用 compile() 进行编译。也可以通过此参数将关键字参数传递给 torch.compile。默认为 False

  • use_gpu (bool or torch.device) – 如果为 True(或传递了设备),则将使用加速器来检索轨迹开始的索引。当缓冲区内容很大时,这可以显著加速采样。默认为 False

注意

为了恢复存储中的轨迹分割,SliceSamplerWithoutReplacement 将首先尝试在存储中查找 traj_key 条目。如果找不到,将使用 end_key 来重建回合。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
>>>
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1000),
...     # asking for 10 slices for a total of 320 elements, ie, 10 trajectories of 32 transitions each
...     sampler=SliceSamplerWithoutReplacement(num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> # since we want trajectories of 32 transitions but there are only 4 episodes to
>>> # sample from, we only get 4 x 32 = 128 transitions in this batch
>>> print("sample:", sample)
>>> print("trajectories in sample", sample.get("episode").unique())

SliceSamplerWithoutReplacement 与大多数 TorchRL 的数据集默认兼容,并允许用户以类似数据加载器的方式消费数据集。

示例

>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSamplerWithoutReplacement
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320,
...     sampler=SliceSamplerWithoutReplacement(num_slices=num_slices))
>>> # the last sample is kept, since drop_last=False by default
>>> for i, batch in enumerate(data):
...     print(batch.get("episode").unique())
tensor([ 5,  6,  8, 11, 12, 14, 16, 17, 19, 24])
tensor([ 1,  2,  7,  9, 10, 13, 15, 18, 21, 22])
tensor([ 0,  3,  4, 20, 23])

当请求大量总样本,但轨迹较少且跨度较小时,批次最多只会包含每个轨迹的一个样本。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.collectors.utils import split_trajectories
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
>>>
>>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
...                   sampler=SliceSamplerWithoutReplacement(
...                       slice_len=5, traj_key="episode",strict_length=False
...                   ))
...
>>> ep_1 = TensorDict(
...     {"obs": torch.arange(100),
...     "episode": torch.zeros(100),},
...     batch_size=[100]
... )
>>> ep_2 = TensorDict(
...     {"obs": torch.arange(51),
...     "episode": torch.ones(51),},
...     batch_size=[51]
... )
>>> rb.extend(ep_1)
>>> rb.extend(ep_2)
>>>
>>> s = rb.sample(50)
>>> t = split_trajectories(s, trajectory_key="episode")
>>> print(t["obs"])
tensor([[14, 15, 16, 17, 18],
        [ 3,  4,  5,  6,  7]])
>>> print(t["episode"])
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.]])
>>>
>>> s = rb.sample(50)
>>> t = split_trajectories(s, trajectory_key="episode")
>>> print(t["obs"])
tensor([[ 4,  5,  6,  7,  8],
        [26, 27, 28, 29, 30]])
>>> print(t["episode"])
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.]])

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源