快捷方式

ValueEstimatorBase

class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)[源代码]

价值函数模块的抽象父类。

ValueFunctionBase.forward() 方法将计算值(由价值网络给出)和价值估计(由价值估计器给出)以及优势,并将这些值写入输出 tensordict。

如果只需要价值估计,则应改用 ValueFunctionBase.value_estimate()

default_keys

别名:_AcceptedKeys

abstract forward(tensordict: TensorDictBase, *, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None) TensorDictBase[源代码]

给定 tensordict 中的数据,计算优势估计。

If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.

参数:

tensordict (TensorDictBase) – 一个包含数据的 TensorDict(一个观测键,"action"("next", "reward")("next", "done")("next", "terminated"),以及由环境返回的 "next" tensordict 状态),这些数据用于计算价值估计和 TD 估计。传递给此模块的数据应结构化为 [*B, T, *F],其中 B 是批次大小,T 是时间维度,F 是特征维度。tensordict 的形状必须为 [*B, T]

关键字参数:
  • params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • device (torch.device, optional) – 缓冲区将被实例化的设备。默认为 torch.get_default_device()

返回:

An updated TensorDict with an advantage and a value_error keys as defined in the constructor.

set_keys(**kwargs) None[源代码]

设置 tensordict 键名。

value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: torch.Tensor | None = None, **kwargs)[源代码]

Gets a value estimate, usually used as a target value for the value network.

如果状态值键存在于 tensordict.get(("next", self.tensor_keys.value)) 下,则将使用此值,而无需调用值网络。

参数:
  • tensordict (TensorDictBase) – the tensordict containing the data to read.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • next_value (torch.Tensor, optional) – 下一个状态或状态-动作对的值。与 target_params 互斥。

  • **kwargs – the keyword arguments to be passed to the value network.

Returns: a tensor corresponding to the state value.

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源