DreamerModelLoss¶
- class torchrl.objectives.DreamerModelLoss(*args, **kwargs)[源码]¶
Dreamer 模型损失。
计算 dreamer 世界模型的损失。该损失由 RSSM 的先验和后验之间的 KL 散度、重构观测值的重构损失以及预测奖励的奖励损失组成。
参考: https://arxiv.org/abs/1912.01603。
- 参数:
world_model (TensorDictModule) – 世界模型。
lambda_kl (
float
, optional) – KL 散度损失的权重。默认为:1.0。lambda_reco (
float
, optional) – 重构损失的权重。默认为:1.0。lambda_reward (
float
, optional) – 奖励损失的权重。默认为:1.0。reco_loss (str, optional) – 重构损失。默认为:“l2”。
reward_loss (str, optional) – 奖励损失。默认为:“l2”。
free_nats (int, optional) – 自由 nat 数。默认为:3。
delayed_clamp (bool, optional) – 如果为
True
,则在平均后进行 KL 夹持。如果为 False(默认),则首先将 KL 散度夹持到自由 nat 值,然后进行平均。global_average (bool, optional) – 如果为
True
,则损失将跨所有维度进行平均。否则,将对所有非批次/时间维度执行求和,并对批次和时间进行平均。默认为 False。
- default_keys¶
别名:
_AcceptedKeys
- forward(tensordict: TensorDict) Tensor [源码]¶
它旨在读取一个输入的 TensorDict 并返回另一个包含名为“loss*”的损失键的 tensordict。
将损失分解为其组成部分可以被训练器用于在训练过程中记录各种损失值。输出 tensordict 中存在的其他标量也将被记录。
- 参数:
tensordict – 一个输入的 tensordict,包含计算损失所需的值。
- 返回:
一个没有批处理维度的新 tensordict,其中包含各种损失标量,这些标量将被命名为“loss*”。重要的是,损失必须以这个名称返回,因为它们将在反向传播之前被训练器读取。