快捷方式

LossModule

class torchrl.objectives.LossModule(*args, **kwargs)[源代码]

RL 损失的父类。

LossModule 继承自 nn.Module。它被设计用于读取一个输入的 TensorDict 并返回另一个 tensordict,其中包含名为 "loss_*" 的损失键。

将损失分解为其组成部分可以被训练器用于在训练过程中记录各种损失值。输出 tensordict 中存在的其他标量也将被记录。

变量:

default_value_estimator – 类的默认值类型。需要值估计的损失会配备一个默认值指针。这个类属性指示了将使用哪个值估计器,如果未指定其他值估计器的话。可以通过 make_value_estimator() 方法更改值估计器。

默认情况下,forward 方法始终使用 gh torchrl.envs.ExplorationType.MEAN 进行装饰。

要利用通过 set_keys() 配置 tensordict 键的能力,子类必须定义一个 _AcceptedKeys dataclass。这个 dataclass 应包含所有打算可配置的键。此外,子类必须实现 :meth:._forward_value_estimator_keys() 方法。此函数对于将任何修改后的 tensordict 键转发到底层 value_estimator 至关重要。

示例

>>> class MyLoss(LossModule):
>>>     @dataclass
>>>     class _AcceptedKeys:
>>>         action = "action"
>>>
>>>     def _forward_value_estimator_keys(self, **kwargs) -> None:
>>>         pass
>>>
>>> loss = MyLoss()
>>> loss.set_keys(action="action2")

注意

当将一个被包装或增强了探索模块的策略传递给 loss 时,我们希望通过 set_exploration_type(<exploration>) 来禁用探索,其中 <exploration> 可以是 ExplorationType.MEANExplorationType.MODEExplorationType.DETERMINISTIC。默认值是 DETERMINISTIC,它通过 deterministic_sampling_mode loss 属性设置。如果需要其他探索模式(或者 DETERMINISTIC 不可用),可以更改此属性的值,这将改变模式。

convert_to_functional(module: TensorDictModule, module_name: str, expand_dim: int | None = None, create_target_params: bool = False, compare_against: list[Parameter] | None = None, **kwargs) None[源代码]

将模块转换为函数式以在损失中使用。

参数:
  • module (TensorDictModule兼容) – 一个有状态的 tensordict 模块。来自此模块的参数将被隔离在 <module_name>_params 属性中,而模块的无状态版本将注册在 module_name 属性下。

  • module_name (str) – 模块将被找到的名称。该模块的参数将在 loss_module.<module_name>_params 下找到,而模块本身将在 loss_module.<module_name> 下找到。

  • expand_dim (int, optional) –

    如果提供,模块的参数将沿第一个维度扩展 N 次,其中 N = expand_dim。当使用具有多个配置的目标网络时,应使用此选项。

    注意

    如果提供了 compare_against 值列表,则生成的参数将只是原始参数的解耦扩展。如果未提供 compare_against,则参数的值将在参数内容的最小值和最大值之间均匀重采样。

  • create_target_params (bool, 可选) – 如果为 True,则参数的解耦副本将可用于为名称为 loss_module.<module_name>_target_params 的目标网络提供输入。如果为 False(默认),此属性仍可用,但它将是参数的解耦实例,而不是副本。换句话说,参数值的任何修改将直接反映在目标参数中。

  • compare_against (参数的可迭代对象, 可选) – 如果提供,此参数列表将用作模块参数的比较集。如果参数被扩展(expand_dim > 0),则模块生成的参数将是原始参数的简单扩展。否则,生成的参数将是原始参数的解耦版本。如果为 None,则生成的参数将按预期携带梯度。

forward(tensordict: TensorDictBase) TensorDictBase[源代码]

它旨在读取一个输入的 TensorDict 并返回另一个包含名为“loss*”的损失键的 tensordict。

将损失分解为其组成部分可以被训练器用于在训练过程中记录各种损失值。输出 tensordict 中存在的其他标量也将被记录。

参数:

tensordict – 一个输入的 tensordict,包含计算损失所需的值。

返回:

一个没有批处理维度的新 tensordict,其中包含各种损失标量,这些标量将被命名为“loss*”。重要的是,损失必须以这个名称返回,因为它们将在反向传播之前被训练器读取。

from_stateful_net(network_name: str, stateful_net: Module)[源代码]

根据有状态的网络版本填充模型的参数。

有关如何收集网络的状态化版本,请参阅 get_stateful_net()

参数:
  • network_name (str) – 要重置的网络名称。

  • stateful_net (nn.Module) – 应从中收集参数的状态化网络。

property functional

模块是否功能化。

除非经过专门设计使其不具有功能性,否则所有损失都具有功能性。

get_stateful_net(network_name: str, copy: bool | None = None)[源代码]

返回网络的状态化版本。

这可用于初始化参数。

这些网络通常开箱即用,无法调用,需要调用 vmap 才能执行。

参数:
  • network_name (str) – 要收集的网络名称。

  • copy (bool, optional) –

    如果为 True,则会进行网络的深拷贝。默认为 True

    注意

    如果模块不是函数式的,则不会进行复制。

make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[源代码]

值函数构造函数。

如果需要非默认值函数,必须使用此方法构建。

参数:
  • value_type (ValueEstimators) – 一个 ValueEstimators 枚举类型,指示要使用的值函数。如果未提供,将使用存储在 default_value_estimator 属性中的默认值。生成的估值器类将注册在 self.value_type 中,以便将来进行改进。

  • **hyperparams – 用于值函数的超参数。如果未提供,将使用 default_value_kwargs() 中指示的值。

示例

>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
named_parameters(prefix: str = '', recurse: bool = True) Iterator[tuple[str, torch.nn.parameter.Parameter]][源代码]

返回模块参数的迭代器,同时生成参数的名称和参数本身。

参数:
  • prefix (str) – 为所有参数名称添加前缀。

  • recurse (bool) – 如果为 True,则会生成此模块及其所有子模块的参数。否则,仅生成此模块直接成员的参数。

  • remove_duplicate (bool, optional) – 是否在结果中删除重复的参数。默认为 True。

产生:

(str, Parameter) – 包含名称和参数的元组

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
parameters(recurse: bool = True) Iterator[Parameter][源代码]

返回模块参数的迭代器。

这通常传递给优化器。

参数:

recurse (bool) – 如果为 True,则会生成此模块及其所有子模块的参数。否则,仅生成此模块直接成员的参数。

产生:

Parameter – 模块参数

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
reset_parameters_recursive()[源代码]

重置模块的参数。

set_keys(**kwargs) None[源代码]

设置 tensordict 键名。

示例

>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
property value_estimator: ValueEstimatorBase

价值函数将奖励和即将到来的状态/状态-动作对的价值估计值融合到价值网络的目标价值估计中。

property vmap_randomness

Vmap 随机模式。

vmap 的随机性模式控制当处理具有随机结果的函数(如 randn()rand())时,vmap() 应该如何执行。如果设置为 “error”,任何随机函数都将引发异常,表明 vmap 不知道如何处理该随机调用。

如果设置为 “different”,则 vmap 正在调用的批次中的每个元素将表现不同。如果设置为 “same”,则 vmap 会将相同的结果复制到所有元素。

vmap_randomness 默认情况下是 “error”(如果未检测到任何随机模块),而在其他情况下为 “different”。默认情况下,只有有限数量的模块被列为随机模块,但可以使用 add_random_module() 函数来扩展此列表。

此属性支持设置其值。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源