快捷方式

MultiStepActorWrapper

class torchrl.modules.tensordict_module.MultiStepActorWrapper(*args, **kwargs)[源代码]

一个围绕多步操作演员的包装器。

此类允许在环境中执行宏指令。演员的操作条目必须具有额外的维度才能被消费。它必须紧邻输入 tensordict 的最后一个维度(即在 tensordict.ndim 处)。

如果未提供操作条目键,则会自动从演员中检索,检索方式是一个简单的启发式方法(任何以 "action" 字符串结尾的嵌套键)。

输入 tensordict 中还必须存在一个 "is_init" 条目,用于跟踪当前集合应何时中断,因为遇到了“done”状态。与 action_keys 不同,此键必须是唯一的。

参数:
  • actor (TensorDictModuleBase) – 一个演员。

  • n_steps (int, optional) – 演员一次输出的操作数(前瞻窗口)。默认为 None

关键字参数:
  • action_keys (list of NestedKeys, optional) – 来自环境的操作键。可以从 env.action_keys 中检索。默认为 actor 的所有 out_keys,这些键以 "action" 字符串结尾。

  • init_key (NestedKey, optional) – 指示环境何时重置的条目的键。默认为 "is_init",这是 InitTracker 变换的 out_key

  • keep_dim (bool, optional) – 在索引期间是否保留宏指令的时间维度。默认为 False

示例

>>> import torch.nn
>>> from torchrl.modules.tensordict_module.actors import MultiStepActorWrapper, Actor
>>> from torchrl.envs import CatFrames, GymEnv, TransformedEnv, SerialEnv, InitTracker, Compose
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> time_steps = 6
>>> n_obs = 4
>>> n_action = 2
>>> batch = 5
>>>
>>> # Transforms a CatFrames in a stack of frames
>>> def reshape_cat(data: torch.Tensor):
...     return data.unflatten(-1, (time_steps, n_obs))
>>> # an actor that reads `time_steps` frames and outputs one action per frame
>>> # (actions are conditioned on the observation of `time_steps` in the past)
>>> actor_base = Seq(
...     Mod(reshape_cat, in_keys=["obs_cat"], out_keys=["obs_cat_reshape"]),
...     Mod(torch.nn.Linear(n_obs, n_action), in_keys=["obs_cat_reshape"], out_keys=["action"])
... )
>>> # Wrap the actor to dispatch the actions
>>> actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
>>>
>>> env = TransformedEnv(
...     SerialEnv(batch, lambda: GymEnv("CartPole-v1")),
...     Compose(
...         InitTracker(),
...         CatFrames(N=time_steps, in_keys=["observation"], out_keys=["obs_cat"], dim=-1)
...     )
... )
>>>
>>> print(env.rollout(100, policy=actor, break_when_any_done=False))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 100, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        action_orig: Tensor(shape=torch.Size([5, 100, 6, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        counter: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.int32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 100]),
            device=cpu,
            is_shared=False),
        obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 100]),
    device=cpu,
    is_shared=False)

另请参阅

torchrl.envs.MultiStepEnvWrapper 是此包装器的 EnvBase 对应项:它包装一个环境并解绑操作,逐个元素执行。

forward(tensordict: TensorDictBase) TensorDictBase[源代码]

定义每次调用时执行的计算。

所有子类都应重写此方法。

注意

虽然前向传播的配方需要在此函数中定义,但之后应该调用 Module 实例而不是此函数,因为前者负责运行注册的钩子,而后者则默默地忽略它们。

property init_key: NestedKey

给定批次元素的初始步骤的指示器。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源