PrioritizedSampler¶
- class torchrl.data.replay_buffers.PrioritizedSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, reduction: str = 'max', max_priority_within_buffer: bool = False)[源代码]¶
用于回放缓冲区(replay buffer)的优先采样器。
此采样器实现了“优先经验回放”(PER),如“Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay.”中所述(https://arxiv.org/abs/1511.05952)
核心思想:与从回放缓冲区中均匀采样经验不同,PER 根据经验的“重要性”——通常通过其时序差分(TD)误差的大小来衡量——以一定概率进行采样。这种优先排序可以通过关注信息量最大的经验来加快学习速度。
工作原理:1. 每个经验根据其 TD 误差被赋予一个优先级:\(p_i = |\delta_i| + \epsilon\) 2. 采样概率计算方式为:\(P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}\) 3. 重要性采样权重用于纠正偏差:\(w_i = (N \cdot P(i))^{-\beta}\)
- 参数:
max_capacity (int) – 缓冲区的最大容量。
alpha (
float
) – 指数 \(\alpha\) 决定了优先排序的使用程度。 - \(\alpha = 0\):均匀采样(无优先排序) - \(\alpha = 1\):基于 TD 误差大小的完全优先排序 - 通常值:0.4-0.7 用于平衡优先排序 - 较高的 \(\alpha\) 值意味着对高误差经验进行更积极的优先排序beta (
float
) – 重要性采样的负指数 \(\beta\)。 - \(\beta\) 控制对优先排序引入的偏差的纠正 - \(\beta = 0\):无纠正(倾向于高优先级样本) - \(\beta = 1\):完全纠正(无偏差但可能不稳定) - 通常值:训练开始时从 0.4-0.6 开始,在训练过程中退火到 1.0 - 训练早期较低的 \(\beta\) 值提供稳定性,较高的 \(\beta\) 值减少偏差eps (
float
, 可选) – 添加到优先级的微小常数,以确保没有任何经验的优先级为零。这可以防止经验永远不被采样。默认为 1e-8。reduction (str, 可选) – 多维张量(即存储的轨迹)的归约方法。可以是 "max", "min", "median" 或 "mean" 之一。
max_priority_within_buffer (bool, 可选) – 如果为
True
,则在缓冲区内跟踪最大优先级。如果为False
,则最大优先级跟踪自采样器实例化以来的最大值。
参数指南: - :math:`alpha` (alpha): 控制对高误差经验的优先排序程度
0.4-0.7:学习速度和稳定性之间的良好平衡
1.0:最大优先排序(可能不稳定)
0.0:均匀采样(无优先排序效果)
- :math:`beta` (beta): 控制重要性采样纠正
从 0.4-0.6 开始训练以提高稳定性
在训练过程中退火至 1.0 以减少偏差
较低的值 = 更稳定但有偏差
较高的值 = 偏差较小但可能不稳定
- :math:`epsilon`: 防止优先级为零的微小常数
1e-8:良好的默认值
太小:可能导致数值问题
太大:会降低优先排序的效果
示例
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler >>> from tensordict import TensorDict >>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0)) >>> priority = torch.tensor([0, 1000]) >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) >>> rb.add(data_0) >>> rb.add(data_1) >>> rb.update_priority(torch.tensor([0, 1]), priority=priority) >>> sample, info = rb.sample(10, return_info=True) >>> print(sample) TensorDict( fields={ action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([10]), device=cpu, is_shared=False) >>> print(info) {'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
注意
使用
TensorDictReplayBuffer
可以平滑地更新优先级。>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler >>> from tensordict import TensorDict >>> rb = TDRB( ... storage=LazyTensorStorage(10), ... sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0), ... priority_key="priority", # This kwarg isn't present in regular RBs ... ) >>> priority = torch.tensor([0, 1000]) >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) >>> data = torch.stack([data_0, data_1]) >>> rb.extend(data) >>> rb.update_priority(data) # Reads the "priority" key as indicated in the constructor >>> sample, info = rb.sample(10, return_info=True) >>> print(sample['index']) # The index is packed with the tensordict tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
- update_priority(index: int | torch.Tensor, priority: float | torch.Tensor, *, storage: TensorStorage | None = None) None [源代码]¶
更新由索引指向的数据的优先级。
- 参数:
index (int 或 torch.Tensor) – 要更新的优先级的索引。
priority (Number 或 torch.Tensor) – 被索引元素的新的优先级。
- 关键字参数:
storage (Storage, 可选) – 一个用于将 Nd 索引大小映射到 sum_tree 和 min_tree 的 1d 大小的存储。仅在
index.ndim > 2
时需要。