快捷方式

ConsistentDropout

class torchrl.modules.ConsistentDropout(p: float = 0.5)[源代码]

实现了一个 Dropout 变体,具有一致性 dropout。

该方法在 “Consistent Dropout for Policy Gradient Reinforcement Learning” (Hausknecht & Wagener, 2022) 中提出。

这个 Dropout 变体试图通过在 rollout 期间缓存 dropout 掩码并在更新阶段重用它们来提高训练稳定性和减少更新方差。

您正在查看的类独立于 TorchRL 的其余 API,并且不需要 tensordict 即可运行。 ConsistentDropoutModuleConsistentDropout 的包装器,它利用了 TensorDict 的可扩展性,通过 生成的 dropout 掩码 存储在 transition ``TensorDict 本身中。有关详细说明和用法示例,请参阅此类。

除此之外,与 PyTorch 的 Dropout 实现相比,概念上的偏差很小。

..note:: TorchRL 的数据收集器在 no_grad() 模式下执行 rollout,但不在 eval 模式下执行,

因此,除非传递给收集器的策略处于 eval 模式,否则将应用 dropout 掩码。

注意

与其他探索模块不同,ConsistentDropoutModule 使用 train/eval 模式以符合 PyTorch 中常规的 Dropout API。 set_exploration_type() 上下文管理器对此模块无效。

参数:

p (float, 可选) – Dropout 概率。默认为 0.5

另请参阅

forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[源代码]

在训练(rollouts & updates)期间,此调用在乘以输入张量之前,会掩盖一个全为 1 的张量。

在评估期间,此调用将不执行任何操作,仅返回输入。

参数:

返回: 在训练模式下返回一个张量和一个对应的掩码,在评估模式下仅返回一个张量。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源