快捷方式

next_state_value

class torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: TensorDictModule | None = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Tensor | None = None, **kwargs)[源代码]

计算下一个状态值(不带梯度),用于计算目标值。

目标值通常用于计算距离损失(例如 MSE)。

L = Sum[ (q_value - target_value)^2 ]

目标值计算方式为:

r + gamma ** n_steps_to_next * value_next_state

如果奖励是即时奖励,则 n_steps_to_next=1。如果使用了 N 步奖励,则 n_steps_to_next 将从输入 tensordict 中获取。

参数:
  • tensordict (TensorDictBase) – 包含奖励和完成键(以及 N 步奖励的 n_steps_to_next 键)的 tensordict。

  • operator (ProbabilisticTDModule, optional) – 值函数算子。调用时应在输入 tensordict 中写入 'next_val_key' 键值对。如果提供了 pred_next_val,则不需要提供此参数。

  • next_val_key (str, optional) – 写入下一个值所用的键。默认值:'state_action_value'

  • gamma (float, optional) – 折扣率。默认值:0.99

  • pred_next_val (Tensor, optional) – 如果下一个状态值不是通过算子计算的,则可以提供。

返回:

一个包含预测值状态的、大小与输入 tensordict 相同的张量。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源