快捷方式

PrioritizedSampler

class torchrl.data.replay_buffers.PrioritizedSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, reduction: str = 'max', max_priority_within_buffer: bool = False)[源代码]

经验回放的优先采样器。

此采样器实现了优先经验回放 (PER),如“Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay.” 所述 (https://arxiv.org/abs/1511.05952)。

核心思想:PER 不再从经验回放缓冲区中统一采样经验,而是根据其“重要性”(通常由其时序差 (TD) 误差的大小衡量)的概率来采样经验。这种优先排序可以通过关注最具信息量的经验来加速学习。

工作原理:1. 每个经验根据其 TD 误差分配优先级:\(p_i = |\delta_i| + \epsilon\) 2. 采样概率计算如下:\(P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}\) 3. 重要性采样权重用于纠正偏差:\(w_i = (N \cdot P(i))^{-\beta}\)

参数:
  • max_capacity (int) – 缓冲区的最大容量。

  • alpha (float) – 指数 \(\alpha\) 决定了优先排序的程度。 - \(\alpha = 0\):统一采样(无优先排序) - \(\alpha = 1\):基于 TD 误差大小的完全优先排序 - 典型值:0.4-0.7,用于平衡优先排序 - 较高的 \(\alpha\) 值意味着对高误差经验的优先排序更激进

  • beta (float) – 重要性采样的负指数 \(\beta\)。 - \(\beta\) 控制对优先排序引入的偏差的纠正 - \(\beta = 0\):无纠正(偏向于高优先级样本) - \(\beta = 1\):完全纠正(无偏但可能不稳定) - 典型值:在训练期间从 0.4-0.6 开始,然后退火到 1.0 - 训练早期较低的 \(\beta\) 值提供稳定性,后期较高的 \(\beta\) 值减少偏差

  • eps (float, 可选) – 添加到优先级的微小常数,以确保没有任何经验的优先级为零。这可以防止某些经验永远不被采样。默认为 1e-8。

  • reduction (str, 可选) – 用于多维 tensordicts(即存储的轨迹)的缩减方法。可以是 "max"、"min"、"median" 或 "mean" 之一。

  • max_priority_within_buffer (bool, 可选) – 如果为 True,则在缓冲区内跟踪最大优先级。如果为 False,则最大优先级会跟踪自采样器实例化以来的最大值。

参数指南: - :math:`alpha` (alpha):控制优先排序高误差经验的程度

  • 0.4-0.7:学习速度和稳定性之间的良好平衡

  • 1.0:最大优先排序(可能不稳定)

  • 0.0:统一采样(无优先排序益处)

  • :math:`beta` (beta):控制重要性采样纠正
    • 从 0.4-0.6 开始进行训练以获得稳定性

    • 在训练过程中退火到 1.0 以减少偏差

    • 较低的值 = 更稳定但有偏差

    • 较高的值 = 偏差较小但可能不稳定

  • :math:`epsilon`:防止优先级为零的小常数
    • 1e-8:良好的默认值

    • 太小:可能导致数值问题

    • 太大:会降低优先排序的效果

示例

>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> rb.add(data_0)
>>> rb.add(data_1)
>>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample)
TensorDict(
        fields={
            action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
            reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([10]),
        device=cpu,
        is_shared=False)
>>> print(info)
{'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
       1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}

注意

使用 TensorDictReplayBuffer 可以平滑更新优先级的过程

>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = TDRB(
...     storage=LazyTensorStorage(10),
...     sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
...     priority_key="priority",  # This kwarg isn't present in regular RBs
... )
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> data = torch.stack([data_0, data_1])
>>> rb.extend(data)
>>> rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample['index'])  # The index is packed with the tensordict
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
update_priority(index: int | torch.Tensor, priority: float | torch.Tensor, *, storage: TensorStorage | None = None) None[源代码]

更新由索引指向的数据的优先级。

参数:
  • index (inttorch.Tensor) – 要更新优先级的索引。

  • priority (Numbertorch.Tensor) – 索引元素的新的优先级。

关键字参数:

storage (Storage, 可选) – 一个用于将 Nd 索引大小映射到 sum_tree 和 min_tree 的一维大小的存储。只有当 index.ndim > 2 时才需要。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源