LSTMModule¶
- class torchrl.modules.LSTMModule(*args, **kwargs)[源代码]¶
LSTM 模块的嵌入器。
此类为
torch.nn.LSTM添加了以下功能:与 TensorDict 的兼容性:隐藏状态将被重塑以匹配 tensordict 的批处理大小。
可选的多步执行:使用 torch.nn,必须在
torch.nn.LSTMCell和torch.nn.LSTM之间进行选择,前者兼容单步输入,后者兼容多步输入。此类同时支持这两种用法。
构造后,该模块不处于循环模式,即它将期望单步输入。
如果处于循环模式,则 tensordict 的最后一个维度应标记为步数。tensordict 的维度没有限制(除非对于时间输入,维度必须大于一)。
注意
此类可以处理沿时间维度的多个连续轨迹,但是在这种情况下,不应信任最终隐藏值(即,不应将其重新用于连续轨迹)。原因是 LSTM 只返回最后一个隐藏值,对于我们提供的填充输入,该值可能对应于一个 0 填充的输入。
- 参数:
input_size – 输入 x 中预期特征的数量
hidden_size – 隐藏状态 h 中的特征数量
num_layers – 循环层数。例如,设置
num_layers=2意味着将两个 LSTM 堆叠在一起形成一个“堆叠 LSTM”,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认为:1bias – 如果为
False,则该层不使用偏置权重 b_ih 和 b_hh。默认值:Truedropout – 如果非零,则在除最后一层外的每个 LSTM 层的输出上引入“Dropout”层,其 dropout 概率等于
dropout。默认为:0python_based — 如果为
True,将使用完整的 Python 实现的 LSTM 单元。默认值:False
- 关键字参数:
in_key (str 或 tuple of str) – 模块的输入键。与
in_keys互斥使用。如果提供,循环键假定为 [“recurrent_state_h”, “recurrent_state_c”],并且in_key将添加到它们之前。in_keys (list of str) – 一组三个字符串,分别对应输入值、第一个和第二个隐藏键。与
in_key互斥。out_key (str 或 tuple of str) – 模块的输出键。与
out_keys互斥使用。如果提供,循环键假定为 [(“next”, “recurrent_state_h”), (“next”, “recurrent_state_c”)],并且out_key将添加到它们之前。out_keys (list of str) –
一组三个字符串,分别对应输出值、第一个和第二个隐藏键。.. 注意
For a better integration with TorchRL's environments, the best naming for the output hidden key is ``("next", <custom_key>)``, such that the hidden values are passed from step to step during a rollout.device (torch.device 或 compatible) – 模块的设备。
lstm (torch.nn.LSTM, optional) – 要包装的 LSTM 实例。与其他 nn.LSTM 参数互斥。
default_recurrent_mode (bool, optional) – 如果提供,则为循环模式,如果尚未被
set_recurrent_mode上下文管理器/装饰器覆盖。默认为False。
- 变量:
recurrent_mode – 返回模块的循环模式。
注意
此模块依赖于输入 TensorDict 中存在的特定
recurrent_state键。要生成TensorDictPrimer转换,该转换将自动向环境 TensorDict 添加隐藏状态,请使用方法make_tensordict_primer()。如果此类是更大模块的子模块,则可以对父模块调用方法get_primers_from_module(),以自动生成此模块所需的所有 primer 转换。示例
>>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP, LSTMModule >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm_module = LSTMModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs_h", "rs_c"], ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) TensorDict( fields={ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ rs_c: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False), rs_h: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)
- forward(tensordict: TensorDictBase = None)[源代码]¶
定义每次调用时执行的计算。
所有子类都应重写此方法。
注意
尽管前向传播的实现需要在此函数中定义,但您应该在之后调用
Module实例而不是此函数,因为前者会处理注册的钩子,而后者则会静默忽略它们。
- make_cudnn_based() LSTMModule[源代码]¶
将 LSTM 层转换为其基于 CuDNN 的版本。
- 返回:
self
- make_python_based() LSTMModule[源代码]¶
将 LSTM 层转换为其基于 Python 的版本。
- 返回:
self
- make_tensordict_primer()[源代码]¶
为环境创建一个 tensordict primer。
一个
TensorDictPrimer对象将确保策略在 Rollout 执行期间了解补充输入和输出(循环状态)。这样,数据就可以在进程之间共享并得到妥善处理。使用批处理环境(如
ParallelEnv)时,该转换可以在单个环境实例级别(即,一组具有内部设置的 tensordict primers 的转换后的环境)或在批处理环境实例级别(即,一组普通环境的转换后的批处理)上使用。如果在环境中未包含
TensorDictPrimer,可能会导致行为不当,例如在并行设置中,一个步骤涉及将新的循环状态从"next"复制到根 tensordict,而~torchrl.EnvBase.step_mdp方法将无法执行此操作,因为循环状态未在环境规范中注册。有关生成给定模块所有 primer 的方法,请参阅
torchrl.modules.utils.get_primers_from_module()。示例
>>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP, LSTMModule >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm_module = LSTMModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs_h", "rs_c"], ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) >>> env = env.append_transform(lstm_module.make_tensordict_primer()) >>> data_collector = SyncDataCollector( ... env, ... policy, ... frames_per_batch=10 ... ) >>> for data in data_collector: ... print(data) ... break