快捷方式

SliceSampler

class torchrl.data.replay_buffers.SliceSampler(*, num_slices: int | None = None, slice_len: int | None = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | tuple[bool | int, bool | int] = False, use_gpu: torch.device | bool = False)[来源]

沿着第一个维度对数据切片进行采样,给定开始和停止信号。

此类以有放回的方式对子轨迹进行采样。无放回版本请参阅 SliceSamplerWithoutReplacement

注意

SliceSampler 检索轨迹索引可能较慢。为了加速其执行,请优先使用 end_key 而不是 traj_key,并考虑以下关键字参数: compilecache_valuesuse_gpu

关键字参数:
  • num_slices (int) – 要采样的切片数量。batch-size 必须大于等于 num_slices 参数。与 slice_len 互斥。

  • slice_len (int) – 要采样的切片的长度。batch-size 必须大于等于 slice_len 参数并且可被其整除。与 num_slices 互斥。

  • end_key (NestedKey, optional) – 指示轨迹(或回合)结束的键。默认为 ("next", "done")

  • traj_key (NestedKey, optional) – 指示轨迹的键。默认为 "episode"(在 TorchRL 的数据集中常用)。

  • ends (torch.Tensor, optional) – 一个包含运行结束信号的 1D 布尔张量。当 end_keytraj_key 获取成本很高,或者该信号易于获取时使用。必须与 cache_values=True 一起使用,并且不能与 end_keytraj_key 结合使用。如果提供,则假定存储已满,并且如果 ends 张量的最后一个元素为 False,则相同的轨迹跨越了结束和开始。

  • trajectories (torch.Tensor, optional) – 一个包含运行 ID 的 1D 整数张量。当 end_keytraj_key 获取成本很高,或者该信号易于获取时使用。必须与 cache_values=True 一起使用,并且不能与 end_keytraj_key 结合使用。如果提供,则假定存储已满,并且如果轨迹张量的最后一个元素与第一个元素相同,则相同的轨迹跨越了结束和开始。

  • cache_values (bool, optional) –

    用于静态数据集。将缓存轨迹的开始和结束信号。即使在调用 extend 的调用期间轨迹索引发生变化,也可以安全地使用此参数,因为此操作将清除缓存。

    警告

    cache_values=True 不能与一个由另一个缓冲区扩展的存储一起使用。例如

    >>> buffer0 = ReplayBuffer(storage=storage,
    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
    ...     writer=ImmutableWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...     sampler=other_sampler)
    >>> # Wrong! Does not erase the buffer from the sampler of buffer0
    >>> buffer1.extend(data)
    

    警告

    cache_values=True 在缓冲区由多个进程共享,一个进程负责写入而另一个进程负责采样时,将无法按预期工作,因为清除缓存只能在本地进行。

  • truncated_key (NestedKey, optional) – 如果不为 None,则此参数指示在输出数据中写入截断信号的位置。这用于指示价值估计器提供的轨迹在哪里中断。默认为 ("next", "truncated")。此功能仅适用于 TensorDictReplayBuffer 实例(否则截断键将在 sample() 方法返回的信息字典中)。

  • strict_length (bool, optional) – 如果为 False,则长度小于 slice_len(或 batch_size // num_slices)的轨迹将被允许出现在批次中。如果为 True,则长度小于要求的轨迹将被过滤掉。请注意,这可能导致实际 batch_size 小于所请求的!轨迹可以使用 split_trajectories() 进行分割。默认为 True

  • compile (booldict of kwargs, optional) – 如果为 True,则 sample() 方法的瓶颈将使用 compile() 进行编译。关键字参数也可以通过此参数传递给 torch.compile。默认为 False

  • span (bool, int, Tuple[bool | int, bool | int], optional) – 如果提供,则采样的轨迹将跨越左侧和/或右侧。这意味着可能提供的元素少于所需的元素。布尔值表示每个轨迹至少采样一个元素。整数 i 表示每个采样轨迹至少收集 slice_len - i 个样本。使用元组可以对左侧(存储轨迹的开头)和右侧(存储轨迹的结尾)的跨度进行精细控制。

  • use_gpu (booltorch.device) – 如果为 True(或传入了设备),则将使用加速器来检索轨迹开始的索引。这可以显著加速在缓冲区内容很大的情况下进行采样。默认为 False

注意

为了在存储中恢复轨迹分割,SliceSampler 将首先尝试在存储中查找 traj_key 条目。如果找不到,将使用 end_key 来重建回合。

注意

当使用 strict_length=False 时,建议使用 split_trajectories() 来分割采样的轨迹。但是,如果来自同一回合的两个样本相邻放置,这可能会产生不正确的结果。为避免此问题,请考虑以下解决方案之一:

  • 使用带切片采样器的 TensorDictReplayBuffer 实例

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.collectors.utils import split_trajectories
    >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
    >>>
    >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
    ...                   sampler=SliceSampler(
    ...                       slice_len=5, traj_key="episode",strict_length=False,
    ...                   ))
    ...
    >>> ep_1 = TensorDict(
    ...     {"obs": torch.arange(100),
    ...     "episode": torch.zeros(100),},
    ...     batch_size=[100]
    ... )
    >>> ep_2 = TensorDict(
    ...     {"obs": torch.arange(4),
    ...     "episode": torch.ones(4),},
    ...     batch_size=[4]
    ... )
    >>> rb.extend(ep_1)
    >>> rb.extend(ep_2)
    >>>
    >>> s = rb.sample(50)
    >>> print(s)
    TensorDict(
        fields={
            episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
            index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([46]),
                device=cpu,
                is_shared=False),
            obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([46]),
        device=cpu,
        is_shared=False)
    >>> t = split_trajectories(s, done_key="truncated")
    >>> print(t["obs"])
    tensor([[73, 74, 75, 76, 77],
            [ 0,  1,  2,  3,  0],
            [ 0,  1,  2,  3,  0],
            [41, 42, 43, 44, 45],
            [ 0,  1,  2,  3,  0],
            [67, 68, 69, 70, 71],
            [27, 28, 29, 30, 31],
            [80, 81, 82, 83, 84],
            [17, 18, 19, 20, 21],
            [ 0,  1,  2,  3,  0]])
    >>> print(t["episode"])
    tensor([[0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.],
            [1., 1., 1., 1., 0.],
            [0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.]])
    
  • 使用 SliceSamplerWithoutReplacement

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.collectors.utils import split_trajectories
    >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
    >>>
    >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
    ...                   sampler=SliceSamplerWithoutReplacement(
    ...                       slice_len=5, traj_key="episode",strict_length=False
    ...                   ))
    ...
    >>> ep_1 = TensorDict(
    ...     {"obs": torch.arange(100),
    ...     "episode": torch.zeros(100),},
    ...     batch_size=[100]
    ... )
    >>> ep_2 = TensorDict(
    ...     {"obs": torch.arange(4),
    ...     "episode": torch.ones(4),},
    ...     batch_size=[4]
    ... )
    >>> rb.extend(ep_1)
    >>> rb.extend(ep_2)
    >>>
    >>> s = rb.sample(50)
    >>> t = split_trajectories(s, trajectory_key="episode")
    >>> print(t["obs"])
    tensor([[75, 76, 77, 78, 79],
            [ 0,  1,  2,  3,  0]])
    >>> print(t["episode"])
    tensor([[0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.]])
    

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1_000_000),
...     sampler=SliceSampler(cache_values=True, num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)

SliceSampler 与大多数 TorchRL 数据集默认兼容

示例

>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
...     batch = batch.reshape(num_slices, -1)
...     break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
        [14],
        [ 8],
        [10],
        [13],
        [ 4],
        [ 2],
        [ 3],
        [22],
        [ 8]])

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源