ConsistentDropout¶
- class torchrl.modules.ConsistentDropout(p: float = 0.5)[源码]¶
实现一个具有一致性 Dropout 的
Dropout
变体。该方法在 《Consistent Dropout for Policy Gradient Reinforcement Learning》(Hausknecht & Wagener, 2022) 中提出。
此
Dropout
变体通过在 rollout 期间缓存使用的 dropout 掩码并在更新阶段重用它们,来尝试提高训练稳定性和降低更新方差。您正在查看的类独立于 TorchRL API 的其余部分,并且不需要 tensordict 即可运行。
ConsistentDropoutModule
是ConsistentDropout
的一个包装器,它利用了TensorDict``s 的 可扩展性,将生成的 dropout 掩码存储在 transition ``TensorDict
本身中。请参阅此类以获取详细解释和用法示例。除此之外,与 PyTorch 的
Dropout
实现相比,概念上的偏差很小。- .. note:: TorchRL 的数据收集器在
no_grad()
模式下执行 rollout,但在 eval 模式下不执行, 因此,除非传递给收集器的策略处于 eval 模式,否则 dropout 掩码将被应用。
注意
与其他探索模块不同,
ConsistentDropoutModule
使用train
/eval
模式来遵循 PyTorch 的常规 Dropout API。set_exploration_type()
上下文管理器对此模块没有影响。- 参数:
p (
float
, 可选) – Dropout 概率。默认为0.5
。
另请参阅
MultiSyncDataCollector
: 在底层使用_main_async_collector()
(SyncDataCollector
)
- forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor [源码]¶
在训练(rollouts 和 updates)期间,此调用在与输入张量相乘之前,会遮蔽一个全为 1 的张量。
在评估期间,此调用不执行任何操作,只返回输入。
- 参数:
x (torch.Tensor) – 输入张量。
mask (torch.Tensor, 可选) – Dropout 的可选掩码。
返回: 在训练模式下返回一个张量和相应的掩码,在评估模式下只返回一个张量。
- .. note:: TorchRL 的数据收集器在