ValueEstimatorBase¶
- class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)[源代码]¶
价值函数模块的抽象父类。
其
ValueFunctionBase.forward()
方法将计算值(由价值网络给出)和价值估计(由价值估计器给出)以及优势,并将这些值写入输出 tensordict。如果只需要价值估计,则应改用
ValueFunctionBase.value_estimate()
。- default_keys¶
别名:
_AcceptedKeys
- abstract forward(tensordict: TensorDictBase, *, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None) TensorDictBase [源代码]¶
给定 tensordict 中的数据,计算优势估计。
If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.
- 参数:
tensordict (TensorDictBase) – 一个包含数据的 TensorDict(一个观测键,
"action"
,("next", "reward")
,("next", "done")
,("next", "terminated")
,以及由环境返回的"next"
tensordict 状态),这些数据用于计算价值估计和 TD 估计。传递给此模块的数据应结构化为[*B, T, *F]
,其中B
是批次大小,T
是时间维度,F
是特征维度。tensordict 的形状必须为[*B, T]
。- 关键字参数:
params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.
target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.
device (torch.device, optional) – 缓冲区将被实例化的设备。默认为
torch.get_default_device()
。
- 返回:
An updated TensorDict with an advantage and a value_error keys as defined in the constructor.
- value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: torch.Tensor | None = None, **kwargs)[源代码]¶
Gets a value estimate, usually used as a target value for the value network.
如果状态值键存在于
tensordict.get(("next", self.tensor_keys.value))
下,则将使用此值,而无需调用值网络。- 参数:
tensordict (TensorDictBase) – the tensordict containing the data to read.
target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.
next_value (torch.Tensor, optional) – 下一个状态或状态-动作对的值。与
target_params
互斥。**kwargs – the keyword arguments to be passed to the value network.
Returns: a tensor corresponding to the state value.