LossModule¶
- class torchrl.objectives.LossModule(*args, **kwargs)[source]¶
A parent class for RL losses.
LossModule inherits from nn.Module. It is designed to read an input TensorDict and return another tensordict with loss keys named "loss_*".
Splitting the loss into its components can then be used by the trainer to log the various loss values throughout training. Other scalars present in the output tensordict will be logged too.
- Variables:
default_value_estimator – The default value type of the class. Losses that require a value estimation are equipped with a default value pointer. This class attribute indicates which value estimator will be used if none other is specified. The value estimator can be changed using the make_value_estimator() method.
By default, the forward method is always decorated with a torchrl.envs.ExplorationType.MEAN context manager.
To utilize the ability to configure the tensordict keys via set_keys(), a subclass must define an _AcceptedKeys dataclass. This dataclass should include all keys that are intended to be configurable. In addition, the subclass must implement the _forward_value_estimator_keys() method. This method is crucial for forwarding any altered tensordict keys to the underlying value_estimator.
Example
>>> class MyLoss(LossModule):
>>>     @dataclass
>>>     class _AcceptedKeys:
>>>         action = "action"
>>>
>>>     def _forward_value_estimator_keys(self, **kwargs) -> None:
>>>         pass
>>>
>>> loss = MyLoss()
>>> loss.set_keys(action="action2")
Note
When a policy that is wrapped or augmented with an exploration module is passed to the loss, we want to deactivate the exploration through set_exploration_type(<exploration>), where <exploration> is either ExplorationType.MEAN, ExplorationType.MODE or ExplorationType.DETERMINISTIC. The default value is DETERMINISTIC and it is set through the deterministic_sampling_mode loss attribute. If another exploration mode is required (or if DETERMINISTIC is not available), one can change the value of this attribute, which will change the mode, as sketched below.
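A minimal, hedged sketch of switching the sampling mode on an existing loss instance (loss stands for any instantiated LossModule subclass; the import path is an assumption):
>>> from torchrl.envs.utils import ExplorationType
>>> # use mode-based sampling instead of the DETERMINISTIC default
>>> loss.deterministic_sampling_mode = ExplorationType.MODE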
- convert_to_functional(module: TensorDictModule, module_name: str, expand_dim: int | None = None, create_target_params: bool = False, compare_against: list[Parameter] | None = None, **kwargs) None [source]¶
Converts a module to functional to be used in the loss. A usage sketch follows the parameter list below.
- Parameters:
module (TensorDictModule or compatible) – a stateful tensordict module. Parameters from this module will be isolated in the <module_name>_params attribute and a stateless version of the module will be registered under the module_name attribute.
module_name (str) – name where the module will be found. The parameters of the module will be found under loss_module.<module_name>_params, whereas the module will be found under loss_module.<module_name>.
expand_dim (int, optional) – if provided, the parameters of the module will be expanded N times, where N = expand_dim, along the first dimension. This option is to be used whenever a target network with more than one configuration is to be used.
Note
If a compare_against list of values is provided, the resulting parameters will simply be a detached expansion of the original parameters. If compare_against is not provided, the value of the parameters will be resampled uniformly between the minimum and maximum value of the parameter content.
create_target_params (bool, optional) – if True, a detached copy of the parameters will be available to feed a target network under the name loss_module.<module_name>_target_params. If False (default), this attribute will still be available but it will be a detached instance of the parameters, not a copy. In other words, any modification of the parameter value will directly be reflected in the target parameters.
compare_against (iterable of parameters, optional) – if provided, this list of parameters will be used as a comparison set for the parameters of the module. If the parameters are expanded (expand_dim > 0), the resulting parameters for the module will be a simple expansion of the original parameter. Otherwise, the resulting parameters will be a detached version of the original parameters. If None, the resulting parameters will carry gradients as expected.
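A hedged sketch of how this method is typically called from a custom loss constructor; the MyActorLoss class and actor_network argument below are hypothetical placeholders:
>>> from tensordict.nn import TensorDictModule
>>> class MyActorLoss(LossModule):
>>>     def __init__(self, actor_network: TensorDictModule):
>>>         super().__init__()
>>>         # isolates the parameters under self.actor_network_params and
>>>         # creates detached target params under self.actor_network_target_params
>>>         self.convert_to_functional(
>>>             actor_network,
>>>             "actor_network",
>>>             create_target_params=True,
>>>         )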
- forward(tensordict: TensorDictBase) TensorDictBase [source]¶
It is designed to read an input TensorDict and return another tensordict with loss keys named "loss_*".
Splitting the loss into its components can then be used by the trainer to log the various loss values throughout training. Other scalars present in the output tensordict will be logged too.
- Parameters:
tensordict – an input tensordict with the values required to compute the loss.
- Returns:
A new tensordict with no batch dimension containing various loss scalars which will be named "loss_*". It is essential that the losses are returned with this name as they will be read by the trainer before backpropagation.
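As an illustration of this contract, a minimal sketch of how a training loop might consume the output; loss_module and sampled_data are placeholders for an instantiated loss and a batch of collected data:
>>> losses = loss_module(sampled_data)
>>> # aggregate every "loss_*" entry before backpropagation
>>> total_loss = sum(v for k, v in losses.items() if k.startswith("loss_"))
>>> total_loss.backward()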
- from_stateful_net(network_name: str, stateful_net: Module)[source]¶
Populates the parameters of the model given a stateful version of the network.
See get_stateful_net() for details on how to gather a stateful version of the network.
- Parameters:
network_name (str) – the name of the network to reset.
stateful_net (nn.Module) – the stateful network from which the params should be gathered.
- property functional¶
Whether the module is functional.
Unless it has been specifically designed not to be functional, all losses are functional.
- get_stateful_net(network_name: str, copy: bool | None = None)[source]¶
Returns a stateful version of the network.
This can be used to initialize parameters.
Such networks will often not be callable out-of-the-box and will require a vmap call to be executable.
- Parameters:
network_name (str) – the name of the network to gather.
copy (bool, optional) – if True, a deepcopy of the network is made. Defaults to True.
Note
If the module is not functional, no copy is made.
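A hedged sketch of the round trip between get_stateful_net() and from_stateful_net(); the network name "actor_network" is an assumed example:
>>> # gather a stateful copy, re-initialize its weights, then write them back
>>> net = loss.get_stateful_net("actor_network")
>>> for p in net.parameters():
>>>     p.data.normal_(0.0, 0.1)
>>> loss.from_stateful_net("actor_network", net)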
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[source]¶
Value-function constructor.
If a non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – A ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.
Example
>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[tuple[str, torch.nn.parameter.Parameter]] [source]¶
Returns an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- parameters(recurse: bool = True) Iterator[Parameter] [source]¶
Returns an iterator over module parameters.
This is typically passed to an optimizer.
- Parameters:
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields:
Parameter – module parameter
Example
>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- set_keys(**kwargs) None [source]¶
Set tensordict key names.
Example
>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
- property value_estimator: ValueEstimatorBase¶
The value function blends in the reward and the value estimate(s) of the upcoming state(s)/state-action pair(s) into a target value estimate for the value network.
- property vmap_randomness¶
Vmap randomness mode.
The vmap randomness mode controls what vmap() should do when dealing with functions with a random outcome such as randn() and rand(). If "error", any random function will raise an exception indicating that vmap does not know how to handle the random call. If "different", each element of the batch along which vmap is being called will behave differently. If "same", vmap will copy the same result across all elements.
vmap_randomness defaults to "error" if no random module is detected, and to "different" in other cases. By default, only a limited number of modules are listed as random, but the list can be extended using the add_random_module() function.
This property supports setting its value.
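Since the property supports assignment, a minimal sketch (loss stands for any instantiated LossModule):
>>> # let each batch element drawn under vmap behave differently
>>> loss.vmap_randomness = "different"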