next_state_value¶
- class torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: TensorDictModule | None = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Tensor | None = None, **kwargs)[源代码]¶
计算下一个状态的值(不带梯度),用于计算目标值。
- 目标值通常用于计算距离损失(例如 MSE)。
L = Sum[ (q_value - target_value)^2 ]
- 目标值计算如下:
r + gamma ** n_steps_to_next * value_next_state
如果奖励是即时奖励,n_steps_to_next=1。如果使用 N 步奖励,n_steps_to_next 将从输入 tensordict 中收集。
- 参数:
tensordict (TensorDictBase) – 包含奖励和完成键(以及 N 步奖励的 n_steps_to_next 键)的 Tensordict。
operator (ProbabilisticTDModule, optional) – 值函数算子。调用时应在输入 tensordict 中写入 ‘next_val_key’ 键值。如果提供了 pred_next_val,则无需提供此参数。
next_val_key (str, optional) – 将写入下一个值的键。默认为 ‘state_action_value’。
gamma (
float
, optional) – 回报折扣率。默认为 0.99。pred_next_val (Tensor, optional) – 如果下一个状态值不是通过算子计算的,则可以提供。
- 返回:
一个与输入 tensordict 大小相同的张量,包含预测的值状态。