快捷方式

PrioritizedSliceSampler

class torchrl.data.replay_buffers.PrioritizedSliceSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: torch.dtype = torch.float32, reduction: str = 'max', *, num_slices: int | None = None, slice_len: int | None = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | tuple[bool | int, bool | int] = False, max_priority_within_buffer: bool = False)[源码]

使用优先采样,根据开始和停止信号,采样数据沿第一个维度的切片。

此类结合了轨迹采样和优先经验回放 (PER),如“Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay.” 中所述(https://arxiv.org/abs/1511.05952)。

核心思想:此采样器不均匀地采样轨迹切片,而是根据轨迹切片中转换的重要性来优先选择轨迹的起始点。这使得学习能够专注于轨迹中最具信息量的部分。

工作原理:1. 每个转换根据其 TD 误差被分配一个优先级:\(p_i = |\\delta_i| + \\epsilon\) 2. 轨迹起始点以以下概率采样:\(P(i) = \frac{p_i^\alpha}{\\sum_j p_j^\alpha}\) 3. 重要性采样权重用于纠正偏差:\(w_i = (N \\cdot P(i))^{-\beta}\) 4. 从采样到的起始点提取完整的轨迹切片。

有关更多信息,请参阅 SliceSamplerPrioritizedSampler

警告

PrioritizedSliceSampler 将查看单个转换的优先级,并相应地对起始点进行采样。这意味着优先级较低的转换也可能出现在样本中,如果它们紧随另一个优先级较高的转换之后;而优先级很高但接近轨迹末尾的转换,如果不能用作起始点,则可能永远不会被采样。目前,用户有责任使用 update_priority() 来聚合轨迹项的优先级。

参数:
  • max_capacity (int) – 缓冲区的最大容量。

  • alpha (float) – 指数 \(\alpha\) 决定了优先级的程度。 - \(\alpha = 0\):轨迹起始点的均匀采样 - \(\alpha = 1\):基于起始点处 TD 误差幅度的完全优先级 - 典型值:0.4-0.7 以实现平衡的优先级 - 较高的 \(\alpha\) 意味着对高误差轨迹区域的优先级更高。

  • beta (float) – 重要性采样的负指数 \(\beta\)。 - \(\beta\) 控制对优先级引入的偏差的校正 - \(\beta = 0\):无校正(偏向高优先级轨迹区域) - \(\beta = 1\):完全校正(无偏差但可能不稳定) - 典型值:开始时为 0.4-0.6,训练期间逐渐退火至 1.0 - 训练早期较低的 \(\beta\) 可提供稳定性,后期较高的 \(\beta\) 可减少偏差。

  • eps (float, optional) – 添加到优先级的微小常数,以确保没有转换的优先级为零。这可以防止轨迹区域永远不被采样。默认为 1e-8。

  • reduction (str, optional) – 多维 tensordicts(即存储的轨迹)的缩减方法。可以是“max”、“min”、“median”或“mean”之一。

参数指南: - :math:`alpha` (alpha):控制对高误差轨迹区域的优先级程度

  • 0.4-0.7:学习速度和稳定性之间的良好平衡

  • 1.0:最大优先级(可能不稳定)

  • 0.0:均匀采样(无优先级优势)

  • :math:`beta` (beta):控制重要性采样校正
    • 训练初期设置为 0.4-0.6 以获得稳定性

    • 训练过程中退火至 1.0 以减少偏差

    • 较低的值 = 更稳定但有偏差

    • 较高的值 = 偏差较小但可能不稳定

  • :math:`\epsilon`:防止优先级为零的小常数
    • 1e-8:良好的默认值

    • 太小:可能导致数值问题

    • 太大:降低优先级效果

关键字参数:
  • num_slices (int) – 要抽样的切片数量。批量大小必须大于或等于 num_slices 参数。与 slice_len 互斥。

  • slice_len (int) – 要抽样的切片的长度。批量大小必须大于或等于 slice_len 参数,并且可以被其整除。与 num_slices 互斥。

  • end_key (NestedKey, optional) – 指示轨迹(或回合)结束的键。默认为 ("next", "done")

  • traj_key (NestedKey, optional) – 指示轨迹的键。默认为 "episode"(在 TorchRL 的数据集中常用)。

  • ends (torch.Tensor, optional) – 一个一维布尔张量,包含运行结束信号。当 end_keytraj_key 的获取成本很高,或者该信号很容易获得时使用。必须与 cache_values=True 一起使用,并且不能与 end_keytraj_key 结合使用。

  • trajectories (torch.Tensor, optional) – 一个一维整数张量,包含运行 ID。当 end_keytraj_key 的获取成本很高,或者该信号很容易获得时使用。必须与 cache_values=True 一起使用,并且不能与 end_keytraj_key 结合使用。

  • cache_values (bool, optional) –

    用于静态数据集。将缓存轨迹的开始和结束信号。即使在调用 extend 期间轨迹索引发生更改,也可以安全地使用此选项,因为此操作将清除缓存。

    警告

    cache_values=True 在采样器与由另一个缓冲区扩展的存储一起使用时将无法正常工作。例如:

    >>> buffer0 = ReplayBuffer(storage=storage,
    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
    ...     writer=ImmutableWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...     sampler=other_sampler)
    >>> # Wrong! Does not erase the buffer from the sampler of buffer0
    >>> buffer1.extend(data)
    

    警告

    cache_values=True 在缓冲区由多个进程共享,一个进程负责写入而另一个进程负责采样时,将无法按预期工作,因为清除缓存只能在本地进行。

  • truncated_key (NestedKey, optional) – 如果不为 None,则此参数指示截断信号应写入输出数据的位置。这用于告知值估计器提供的轨迹在哪里中断。默认为 ("next", "truncated")。此功能仅适用于 TensorDictReplayBuffer 实例(否则,截断键将在 sample() 方法返回的信息字典中)。

  • strict_length (bool, optional) – 如果为 False,则允许长度小于 slice_len(或 batch_size // num_slices)的轨迹出现在批次中。如果为 True,则将过滤掉长度不足的轨迹。请注意,这可能导致有效的 batch_size 短于请求的 batch_size!可以使用 split_trajectories() 来拆分轨迹。默认为 True

  • compile (booldict of kwargs, optional) – 如果为 True,则 sample() 方法的瓶颈将使用 compile() 进行编译。也可以通过此参数将关键字参数传递给 torch.compile。默认为 False

  • span (bool, int, Tuple[bool | int, bool | int], optional) – 如果提供,则采样的轨迹将跨越左侧和/或右侧。这意味着可能提供的元素少于所需元素。布尔值表示每个轨迹至少采样一个元素。整数 i 表示为每个采样的轨迹收集的样本至少为 slice_len - i。使用元组可以精细控制左侧(存储轨迹的开头)和右侧(存储轨迹的结尾)的跨度。

  • max_priority_within_buffer (bool, optional) – 如果为 True,则在缓冲区内跟踪最大优先级。如果为 False,则最大优先级跟踪自采样器实例化以来的最大值。默认为 False

示例

>>> import torch
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler
>>> from tensordict import TensorDict
>>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6)
>>> data = TensorDict(
...     {
...         "observation": torch.randn(9,16),
...         "action": torch.randn(9, 1),
...         "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long),
...         "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long),
...         ("next", "observation"): torch.randn(9,16),
...         ("next", "reward"): torch.randn(9,1),
...         ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1),
...     },
...     batch_size=[9],
... )
>>> rb.extend(data)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 1, 1]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 1, 2]
>>> print("weight", info["_weight"].tolist())
weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
>>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
>>> rb.update_priority(torch.arange(0,9,1), priority=priority)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 2, 2]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 0, 1]
>>> print("weight", info["_weight"].tolist())
weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
update_priority(index: int | torch.Tensor, priority: float | torch.Tensor, *, storage: TensorStorage | None = None) None

更新由索引指向的数据的优先级。

参数:
  • index (inttorch.Tensor) – 要更新优先级的索引。

  • priority (Numbertorch.Tensor) – 索引元素的新的优先级。

关键字参数:

storage (Storage, optional) – 一个存储,用于将 N 维索引大小映射到 sum_tree 和 min_tree 的一维大小。仅在 index.ndim > 2 时需要。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源