step_mdp¶
- torchrl.envs.step_mdp(tensordict: TensorDictBase, next_tensordict: TensorDictBase = None, keep_other: bool = True, exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, reward_keys: NestedKey | list[NestedKey] = 'reward', done_keys: NestedKey | list[NestedKey] = 'done', action_keys: NestedKey | list[NestedKey] = 'action') TensorDictBase [源代码]¶
创建一个新的 tensordict,以反映输入 tensordict 的一个时间步长。
给定一个在执行一个步长后检索到的 tensordict,返回“next”索引的 tensordict。这些参数允许精确控制应该保留什么以及应该从“next”条目中复制什么。默认行为是:将 observation 条目、reward 和 done 状态移动到根目录,排除当前 action,并保留所有额外的键(非 action、非 done、非 reward)。
- 参数:
tensordict (TensorDictBase) – 需要重命名键的 tensordict。
next_tensordict (TensorDictBase, 可选) – 目标 tensordict。如果为 None,则创建一个新的 tensordict。
keep_other (bool, 可选) – 如果为
True
,则所有不以'next_'
开头的键都将被保留。默认为True
。exclude_reward (bool, 可选) – 如果为
True
,则"reward"
键将从结果 tensordict 中丢弃。如果为False
,它将从“next”条目中复制(并替换)(如果存在)。默认为True
。exclude_done (bool, 可选) – 如果为
True
,则"done"
键将从结果 tensordict 中丢弃。如果为False
,它将从“next”条目中复制(并替换)(如果存在)。默认为False
。exclude_action (bool, 可选) – 如果为
True
,则"action"
键将从结果 tensordict 中丢弃。如果为False
,它将保留在根 tensordict 中(因为它不应该出现在“next”条目中)。默认为True
。reward_keys (NestedKey 或 NestedKey 列表, 可选) – 写入 reward 的键。默认为“reward”。
done_keys (NestedKey 或 NestedKey 列表, 可选) – 写入 done 的键。默认为“done”。
action_keys (NestedKey 或 NestedKey 列表, 可选) – 写入 action 的键。默认为“action”。
- 返回:
包含 t+1 步张量的新的 tensordict(或如果提供了 next_tensordict,则为 next_tensordict)。
- 返回类型:
TensorDictBase
另请参阅
EnvBase.step_mdp()
是此自由函数的类版本。它将尝试缓存键值,以减少 MDP 中执行步长的开销。示例
>>> from tensordict import TensorDict >>> import torch >>> td = TensorDict({ ... "done": torch.zeros((), dtype=torch.bool), ... "reward": torch.zeros(()), ... "extra": torch.zeros(()), ... "next": TensorDict({ ... "done": torch.zeros((), dtype=torch.bool), ... "reward": torch.zeros(()), ... "obs": torch.zeros(()), ... }, []), ... "obs": torch.zeros(()), ... "action": torch.zeros(()), ... }, []) >>> print(step_mdp(td)) TensorDict( fields={ done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False), extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(step_mdp(td, exclude_done=True)) # "done" is dropped TensorDict( fields={ extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(step_mdp(td, exclude_reward=False)) # "reward" is kept TensorDict( fields={ done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False), extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(step_mdp(td, exclude_action=False)) # "action" persists at the root TensorDict( fields={ action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False), extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(step_mdp(td, keep_other=False)) # "extra" is missing TensorDict( fields={ done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False), obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
警告
如果 reward 键也是输入键的一部分,并且 reward 键被排除,此函数将无法正常工作。这就是为什么
RewardSum
转换默认将剧集 reward 注册到 observation 中而不是 reward spec。当使用此函数的快速缓存版本(_StepMDP
)时,不应出现此问题。