快捷方式

GAE

class torchrl.objectives.value.GAE(*args, **kwargs)[源代码]

广义优势估计函数的类包装器。

Refer to “HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION” https://arxiv.org/pdf/1506.02438.pdf for more context.

参数:
  • gamma (scalar) – exponential mean discount.

  • lmbda (scalar) – trajectory discount.

  • value_network (TensorDictModule, optional) – 用于检索值估计的值运算符。如果为 None,此模块将期望 "state_value" 键已填充,并且不会调用值网络来生成它。

  • average_gae (bool) – 如果为 True,则结果的 GAE 值将进行标准化。默认为 False

  • differentiable (bool, optional) –

    if True, gradients are propagated through the computation of the value function. Default is False.

    注意

    The proper way to make the function call non-differentiable is to decorate it in a torch.no_grad() context manager/decorator or pass detached parameters for functional modules.

  • vectorized (bool, optional) – 是否使用 lambda 返回值的向量化版本。如果未编译,则默认为 True

  • skip_existing (bool, optional) – 如果为 True,则值网络将跳过其输出已存在于 tensordict 中的模块。默认为 None,即 tensordict.nn.skip_existing() 的值不受影响。默认为“state_value”。

  • advantage_key (str or tuple of str, optional) – [Deprecated] the key of the advantage entry. Defaults to "advantage".

  • value_target_key (str or tuple of str, optional) – [已弃用] 优势项的键。默认为 "value_target"

  • value_key (str or tuple of str, optional) – [已弃用] 从输入 tensordict 读取的值键。默认为 "state_value"

  • shifted (bool, optional) – 如果设置为 True,值和下一个值将通过对值网络的单次调用来估计。这更快,但仅在以下情况下有效:(1) "next" 值仅偏移一步(例如,对于多步值估计则不适用),并且 (2) 在时间 tt+1 使用的参数相同(在使用目标参数时则不适用)。默认为 False

  • device (torch.device, optional) – 缓冲区将被实例化的设备。默认为 torch.get_default_device()

  • time_dim (int, optional) – 输入 tensordict 中对应时间的维度。如果未提供,则默认为标记有 "time" 名称的维度(如果存在),否则默认为最后一个维度。可以在调用 value_estimate() 时覆盖。负维度相对于输入 tensordict 进行考虑。

  • auto_reset_env (bool, optional) – 如果为 True,则该回合的最后一个 "next" 状态无效,因此 GAE 计算将使用 value 而不是 next_value 来引导截断的回合。

  • deactivate_vmap (bool, optional) – 如果为 True,则不会使用 vmap 调用,向量化映射将替换为简单的 for 循环。默认为 False

GAE 将返回一个包含优势值的 "advantage" 条目。它还将返回一个 "value_target" 条目,其中包含用于训练值网络的返回。最后,如果 gradient_modeTrue,则将返回一个额外的、可微分的 "value_error" 条目,它简单地表示返回与值网络输出之间的差值(即,应将额外的距离损失应用于此带符号值)。

注意

与其他优势函数一样,如果 value_key 已存在于输入 tensordict 中,则 GAE 模块将忽略对值网络的调用(如果有),而是使用提供的值。

注意

GAE 可以与依赖于循环神经网络的值网络一起使用,前提是初始化标记(“is_init”)和终止/截断标记已正确设置。如果 shifted=True,则轨迹批次将被展平,并且每个轨迹的最后一步将放置在扁平 tensordict 中,位于根的最后一步之后,以便每个轨迹有 T+1 个元素。如果 shifted=False,则将堆叠根和 “next” 轨迹,并且值网络将使用 vmap 调用堆叠的轨迹。由于 RNN 需要相当多的控制流,因此它们目前不与 torch.vmap 兼容,因此在这些情况下必须启用 deactivate_vmap 选项。同样,如果 shifted=False,则根 tensordict 的 “is_init” 条目将被复制到 “next” 条目的 “is_init” 上,以便根和 “next” 数据都能得到良好的分隔。

forward(tensordict=None, *, params: list[Tensor] | None = None, target_params: list[Tensor] | None = None, time_dim: int | None = None)[源代码]

根据 tensordict 中的数据计算 GAE。

If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.

参数:

tensordict (TensorDictBase) – 一个 TensorDict,包含计算值估计和 GAE 所需的数据(一个观察键、"action"("next", "reward")("next", "done")("next", "terminated")"next" tensordict 状态,如环境返回)。传递给此模块的数据应结构化为 [*B, T, *F],其中 B 是批次大小,T 是时间维度,F 是特征维度。tensordict 的形状必须为 [*B, T]

关键字参数:
  • params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • time_dim (int, optional) – 输入 tensordict 中对应时间的维度。如果未提供,则默认为标记有 "time" 名称的维度(如果存在),否则默认为最后一个维度。负维度相对于输入 tensordict 进行考虑。

返回:

An updated TensorDict with an advantage and a value_error keys as defined in the constructor.

示例

>>> from tensordict import TensorDict
>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = GAE(
...     gamma=0.98,
...     lmbda=0.94,
...     value_network=value_net,
...     differentiable=False,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward, "terminated": terminated}, [1, 10])
>>> _ = module(tensordict)
>>> assert "advantage" in tensordict.keys()

The module supports non-tensordict (i.e. unpacked tensordict) inputs too

示例

>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = GAE(
...     gamma=0.98,
...     lmbda=0.94,
...     value_network=value_net,
...     differentiable=False,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
value_estimate(tensordict, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None, time_dim: int | None = None, **kwargs)[源代码]

Gets a value estimate, usually used as a target value for the value network.

如果状态值键存在于 tensordict.get(("next", self.tensor_keys.value)) 下,则将使用此值,而无需调用值网络。

参数:
  • tensordict (TensorDictBase) – the tensordict containing the data to read.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • next_value (torch.Tensor, optional) – 下一个状态或状态-动作对的值。与 target_params 互斥。

  • **kwargs – the keyword arguments to be passed to the value network.

Returns: a tensor corresponding to the state value.

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源