快捷方式

step_mdp

torchrl.envs.step_mdp(tensordict: TensorDictBase, next_tensordict: TensorDictBase = None, keep_other: bool = True, exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, reward_keys: NestedKey | list[NestedKey] = 'reward', done_keys: NestedKey | list[NestedKey] = 'done', action_keys: NestedKey | list[NestedKey] = 'action') TensorDictBase[源代码]

创建一个新的 tensordict,反映输入 tensordict 的时间步。

给定一个在步进后检索到的 tensordict,返回 "next" 索引的 tensordict。参数允许精确控制哪些内容应该被保留,哪些内容应该从 "next" 条目中复制。默认行为是:将 observation 条目、奖励和 done 状态移动到根目录,排除当前 action,并保留所有额外的键(非 action、非 done、非 reward)。

参数:
  • tensordict (TensorDictBase) – 要重命名的键的 tensordict。

  • next_tensordict (TensorDictBase, 可选) – 目标 tensordict。如果为 None,则创建一个新的 tensordict。

  • keep_other (bool, 可选) – 如果为 True,则会保留所有不以 'next_' 开头的键。默认为 True

  • exclude_reward (bool, 可选) – 如果为 True,则 "reward" 键将被从结果 tensordict 中丢弃。如果为 False,它将被从 "next" 条目(如果存在)复制(并替换)。默认为 True

  • exclude_done (bool, 可选) – 如果为 True,则 "done" 键将被从结果 tensordict 中丢弃。如果为 False,它将被从 "next" 条目(如果存在)复制(并替换)。默认为 False

  • exclude_action (bool, 可选) – 如果为 True,则 "action" 键将被从结果 tensordict 中丢弃。如果为 False,它将被保留在根 tensordict 中(因为它不应出现在 "next" 条目中)。默认为 True

  • reward_keys (NestedKeyNestedKey 列表, 可选) – 写入奖励的键。默认为“reward”。

  • done_keys (NestedKeyNestedKey 列表, 可选) – 写入 done 的键。默认为“done”。

  • action_keys (NestedKeyNestedKey 列表, 可选) – 写入 action 的键。默认为“action”。

返回:

包含 t+1 步张量的新的 tensordict(或如果提供了 next_tensordict 则为 next_tensordict)。

返回类型:

TensorDictBase

另请参阅

EnvBase.step_mdp() 是此自由函数的基于类的版本。它将尝试缓存键值以减少 MDP 步进的开销。

示例

>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({
...     "done": torch.zeros((), dtype=torch.bool),
...     "reward": torch.zeros(()),
...     "extra": torch.zeros(()),
...     "next": TensorDict({
...         "done": torch.zeros((), dtype=torch.bool),
...         "reward": torch.zeros(()),
...         "obs": torch.zeros(()),
...     }, []),
...     "obs": torch.zeros(()),
...     "action": torch.zeros(()),
... }, [])
>>> print(step_mdp(td))
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_done=True))  # "done" is dropped
TensorDict(
    fields={
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_reward=False))  # "reward" is kept
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_action=False))  # "action" persists at the root
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, keep_other=False))  # "extra" is missing
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

警告

如果奖励键也包含在输入键中(当排除奖励键时),此函数将无法正常工作。这就是为什么 RewardSum 转换默认将剧集奖励注册到 observation 而不是 reward spec。当使用此函数的快速缓存版本(_StepMDP)时,不应观察到此问题。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源