
TD0Estimator

class torchrl.objectives.value.TD0Estimator(*args, **kwargs)[source]

Temporal difference (TD(0)) estimate of the advantage function.

Also known as bootstrapped temporal difference or 1-step return.
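In formula form, the 1-step return bootstraps the next state value, value_target = r_t + γ·V(s_{t+1})·(1 − terminated), and the advantage is the TD error value_target − V(s_t). A minimal, framework-free sketch of this computation (a hypothetical scalar version; the actual module works on batched tensors through the value network):

```python
# Hypothetical scalar sketch of the TD(0) computation this estimator
# performs; the real module operates on batched torch tensors.
def td0_advantage(reward, gamma, value, next_value, terminated):
    # Bootstrapped 1-step return: r + gamma * V(s'), cut off at episode end.
    not_terminal = 0.0 if terminated else 1.0
    value_target = reward + gamma * next_value * not_terminal
    # The advantage is the TD error: target minus the current state value.
    advantage = value_target - value
    return advantage, value_target

adv, tgt = td0_advantage(
    reward=1.0, gamma=0.98, value=0.5, next_value=1.0, terminated=False
)
# tgt = 1.0 + 0.98 * 1.0 = 1.98, adv = 1.98 - 0.5 = 1.48
```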

Keyword Args:
  • gamma (scalar) – exponential mean discount.

  • value_network (TensorDictModule) – the value operator used to retrieve the value estimates.

  • shifted (bool, optional) – if True, the value and next value are estimated with a single call to the value network. This is faster but is only valid whenever (1) the "next" value is shifted by only one time step (which is not the case with multi-step value estimation, for instance) and (2) the parameters used at time t and t+1 are identical (which is not the case when target parameters are to be used). Defaults to False.

  • average_rewards (bool, optional) – if True, rewards will be standardized before the TD is computed.

  • differentiable (bool, optional) –

    if True, gradients are propagated through the computation of the value function. Default is False.

    Note

    The proper way to make the function call non-differentiable is to decorate it in a torch.no_grad() context manager/decorator or pass detached parameters for functional modules.

  • skip_existing (bool, optional) – if True, the value network will skip modules whose outputs are already present in the tensordict. Defaults to None, i.e., the value of tensordict.nn.skip_existing() is not affected.

  • advantage_key (str or tuple of str, optional) – [Deprecated] the key of the advantage entry. Defaults to "advantage".

  • value_target_key (str or tuple of str, optional) – [Deprecated] the key of the value target entry. Defaults to "value_target".

  • value_key (str or tuple of str, optional) – [Deprecated] the value key to read from the input tensordict. Defaults to "state_value".

  • device (torch.device, optional) – the device where the buffers will be instantiated. Defaults to torch.get_default_device().

  • deactivate_vmap (bool, optional) – whether to deactivate vmap calls and replace them with plain for loops. Defaults to False.
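The saving promised by shifted=True can be illustrated with a toy, framework-free sketch (all names here are hypothetical): instead of querying the value network once on the current states and once on the "next" states, a single call is made over the stacked trajectory and the result is shifted, which is only sound when the same parameters serve both time steps:

```python
# Toy value "network" that counts how often it is called (hypothetical).
calls = {"n": 0}
def value_net(states):
    calls["n"] += 1
    return [0.1 * s for s in states]

obs = [0.0, 1.0, 2.0]       # s_0 .. s_2
next_obs = [1.0, 2.0, 3.0]  # s_1 .. s_3 (shifted by exactly one step)

# shifted=False behaviour: two separate value-network calls.
v, next_v = value_net(obs), value_net(next_obs)

# shifted=True behaviour: one call over the stacked trajectory, then a shift.
calls["n"] = 0
joint = value_net(obs + next_obs[-1:])
v2, next_v2 = joint[:-1], joint[1:]
```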

forward(tensordict=None, *, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None)[source]

Computes the TD(0) advantage given the data in tensordict.

If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.

Args:

tensordict (TensorDictBase) – A TensorDict containing the data (an observation key, "action", ("next", "reward"), ("next", "done"), ("next", "terminated") and the "next" tensordict state as returned by the environment) necessary to compute the value estimates and the TD(0) estimate. The data passed to this module should be structured as [*B, T, *F] where B are the batch size(s), T the time dimension and F the feature dimension(s). The tensordict must have shape [*B, T].

Keyword Args:
  • params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

Returns:

An updated TensorDict with the advantage and value_error keys as defined in the constructor.

Examples

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.objectives.value import TD0Estimator
>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = TD0Estimator(
...     gamma=0.98,
...     value_network=value_net,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "terminated": terminated, "reward": reward}}, [1, 10])
>>> _ = module(tensordict)
>>> assert "advantage" in tensordict.keys()

The module also supports non-tensordict (i.e. unpacked tensordict) inputs:

Examples

>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = TD0Estimator(
...     gamma=0.98,
...     value_network=value_net,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: torch.Tensor | None = None, **kwargs)[source]

Gets a value estimate, usually used as a target value for the value network.

If the state value key is present under tensordict.get(("next", self.tensor_keys.value)), this value will be used without calling the value network.

Args:
  • tensordict (TensorDictBase) – the tensordict containing the data to read.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • next_value (torch.Tensor, optional) – the value of the next state or state-action pair. Exclusive with target_params.

  • **kwargs – the keyword arguments to be passed to the value network.

Returns: a tensor corresponding to the state value.
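The lookup order described above can be sketched in plain Python (a simplified stand-in where data is a plain dict playing the role of the TensorDict; the function and key names are hypothetical):

```python
# Hypothetical sketch of how value_estimate resolves the next state value.
def next_state_value(data, value_network, next_value=None, gamma=0.98):
    if next_value is None:
        # 1. Reuse a pre-computed ("next", "state_value") entry if present.
        next_value = data.get(("next", "state_value"))
    if next_value is None:
        # 2. Otherwise fall back to a call to the value network.
        next_value = value_network(data["next_obs"])
    # Bootstrapped target, cut off when the episode is done.
    return data["next_reward"] + gamma * next_value * (1.0 - data["next_done"])

data = {("next", "state_value"): 0.7, "next_reward": 1.0, "next_done": 0.0}
target = next_state_value(data, value_network=None)  # network never called
# target = 1.0 + 0.98 * 0.7 = 1.686
```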
