MultiStepActorWrapper¶
- class torchrl.modules.tensordict_module.MultiStepActorWrapper(*args, **kwargs)[source]¶
A wrapper for actors that output multi-step actions.

This class allows macros (multi-step actions) to be executed in an environment. The actor's action entry (or entries) must carry an extra time dimension to be consumed, placed just after the last dimension of the input tensordict (i.e., at index tensordict.ndim). If not provided, the action entry keys are retrieved from the actor through a simple heuristic (any nested key ending with the "action" string). An "is_init" entry must also be present in the input tensordict, to track when the current collection of actions should be interrupted because a "done" state was encountered. Unlike action_keys, this key must be unique.
- Parameters:
    actor (TensorDictModuleBase) – An actor.
    n_steps (int, optional) – The number of actions the actor outputs at a time (the lookahead window). Defaults to None.
- Keyword Arguments:
    action_keys (list of NestedKeys, optional) – The action keys of the environment. Can be retrieved from env.action_keys. Defaults to all out_keys of the actor that end with the "action" string.
    init_key (NestedKey, optional) – The key of the entry indicating when the environment has been reset. Defaults to "is_init", which is the out_key of the InitTracker transform.
    keep_dim (bool, optional) – Whether to keep the time dimension of the macro during indexing. Defaults to False.
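The buffering behavior described above can be pictured as a small dispatch loop: the actor plans n_steps actions at once, the wrapper caches them and pops one per environment step, and re-plans whenever the cache runs out or an "is_init" reset invalidates the pending macro. Below is a minimal pure-Python sketch of that idea; the MacroBuffer helper is hypothetical and is not TorchRL's actual implementation.

```python
# Minimal sketch of the macro-action dispatch idea behind
# MultiStepActorWrapper (hypothetical helper, not TorchRL code).
from collections import deque


class MacroBuffer:
    """Caches n_steps actions from an actor and pops one per env step."""

    def __init__(self, actor, n_steps):
        self.actor = actor      # callable: obs -> sequence of n_steps actions
        self.n_steps = n_steps
        self.queue = deque()

    def step(self, obs, is_init):
        # A reset ("is_init") invalidates any pending macro actions,
        # as does running out of cached actions: re-plan in both cases.
        if is_init or not self.queue:
            self.queue = deque(self.actor(obs))
        return self.queue.popleft()


# Usage: a toy "actor" that plans 3 actions at a time.
plan = lambda obs: [obs + i for i in range(3)]
buf = MacroBuffer(plan, n_steps=3)
actions = [buf.step(obs=10, is_init=(t == 0)) for t in range(4)]
# The 4th step exhausts the macro and triggers a re-plan.
```

In the real wrapper the cached actions live in the tensordict (stacked along the extra time dimension) rather than in a Python queue, but the refill conditions are the same.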
Examples
>>> import torch.nn
>>> from torchrl.modules.tensordict_module.actors import MultiStepActorWrapper, Actor
>>> from torchrl.envs import CatFrames, GymEnv, TransformedEnv, SerialEnv, InitTracker, Compose
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> time_steps = 6
>>> n_obs = 4
>>> n_action = 2
>>> batch = 5
>>>
>>> # Transforms a CatFrames in a stack of frames
>>> def reshape_cat(data: torch.Tensor):
...     return data.unflatten(-1, (time_steps, n_obs))
>>> # an actor that reads `time_steps` frames and outputs one action per frame
>>> # (actions are conditioned on the observation of `time_steps` in the past)
>>> actor_base = Seq(
...     Mod(reshape_cat, in_keys=["obs_cat"], out_keys=["obs_cat_reshape"]),
...     Mod(torch.nn.Linear(n_obs, n_action), in_keys=["obs_cat_reshape"], out_keys=["action"])
... )
>>> # Wrap the actor to dispatch the actions
>>> actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
>>>
>>> env = TransformedEnv(
...     SerialEnv(batch, lambda: GymEnv("CartPole-v1")),
...     Compose(
...         InitTracker(),
...         CatFrames(N=time_steps, in_keys=["observation"], out_keys=["obs_cat"], dim=-1)
...     )
... )
>>>
>>> print(env.rollout(100, policy=actor, break_when_any_done=False))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 100, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        action_orig: Tensor(shape=torch.Size([5, 100, 6, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        counter: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.int32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 100]),
            device=cpu,
            is_shared=False),
        obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 100]),
    device=cpu,
    is_shared=False)
See also
torchrl.envs.MultiStepEnvWrapper is the EnvBase counterpart of this wrapper: it wraps an environment and unbinds the actions, executing them one element at a time.
- forward(tensordict: TensorDictBase) → TensorDictBase[source]¶
Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.
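The note above is worth seeing concretely: calling the instance runs any registered hooks, while invoking forward() directly skips them. The toy class below is a plain-Python stand-in for that dispatch pattern, not the actual torch.nn.Module machinery.

```python
# Toy illustration of why one calls the module instance rather than
# .forward() directly: __call__ runs registered hooks, forward() does not.
# (Plain-Python stand-in; the real logic lives in torch.nn.Module.)
class ToyModule:
    def __init__(self):
        self.hooks = []

    def register_hook(self, fn):
        self.hooks.append(fn)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        out = self.forward(x)
        # Forward hooks see the module, its input, and its output.
        for hook in self.hooks:
            hook(self, x, out)
        return out


calls = []
m = ToyModule()
m.register_hook(lambda mod, inp, out: calls.append((inp, out)))

m(3)          # runs the hook: calls records (3, 6)
m.forward(3)  # silently skips the hook
```

Anything relying on forward hooks (logging, feature extraction, quantization observers) breaks if forward() is called directly, which is why the docs insist on calling the instance.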
- property init_key: NestedKey¶
The indicator of the initial step for a given element of the batch.