快捷方式

LSTMModule

class torchrl.modules.LSTMModule(*args, **kwargs)[源代码]

一个 LSTM 模块的嵌入器。

此类向 torch.nn.LSTM 添加了以下功能:

  • 与 TensorDict 的兼容性:隐藏状态被重塑以匹配 tensordict 的批处理大小。

  • 可选的多步执行:使用 torch.nn,必须在 torch.nn.LSTMCelltorch.nn.LSTM 之间进行选择,前者与单步输入兼容,后者与多步兼容。此类支持这两种用法。

构造完成后,模块**未**设置为循环模式,即它将期望单步输入。

如果处于循环模式,则 tensordict 的最后一个维度被假定为标记步数。对 tensordict 的维度没有限制(除了对于时序输入必须大于一个)。

注意

此类可以处理时间维度上的多个连续轨迹,**但是**在这些情况下,最终隐藏值不应被信任(即不应将其重新用于连续轨迹)。原因是 LSTM 只返回最后一个隐藏值,对于我们提供的填充输入,这可能对应于一个 0 填充的输入。

参数:
  • input_size – 输入 x 中预期特征的数量

  • hidden_size – 隐藏状态 h 中的特征数量

  • num_layers – 循环层数。例如,将 num_layers=2 设置为将两个 LSTM 堆叠在一起形成一个“堆叠 LSTM”,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认为 1。

  • bias – 如果为 False,则该层不使用偏置权重 b_ihb_hh。默认为 True

  • dropout – 如果非零,则在除最后一层外的每一 LSTM 层的输出上引入一个“Dropout”层,其 dropout 概率等于 dropout。默认为 0。

  • python_based – 如果为 True,将使用完整的 Python 实现的 LSTM 单元。默认为 False

关键字参数:
  • in_key (strtuple of str) – 模块的输入键。与 in_keys 互斥使用。如果提供了 in_key,则假定循环键为 [“recurrent_state_h”, “recurrent_state_c”],并将 in_key 附加到它们之前。

  • in_keys (list of str) – 一个字符串三元组,对应于输入值、第一个和第二个隐藏键。与 in_key 互斥。

  • out_key (strtuple of str) – 模块的输出键。与 out_keys 互斥使用。如果提供了 out_key,则假定循环键为 [(“next”, “recurrent_state_h”), (“next”, “recurrent_state_c”)],并将 out_key 附加到它们之前。

  • out_keys (list of str) –

    一个字符串三元组,对应于输出值、第一个和第二个隐藏键。.. note

    For a better integration with TorchRL's environments, the best naming
    for the output hidden key is ``("next", <custom_key>)``, such
    that the hidden values are passed from step to step during a rollout.
    

  • device (torch.device兼容设备) – 模块的设备。

  • lstm (torch.nn.LSTM, 可选) – 要包装的 LSTM 实例。与其他 nn.LSTM 参数互斥。

  • default_recurrent_mode (bool, 可选) – 如果提供,则为循环模式,如果尚未被 set_recurrent_mode 上下文管理器/装饰器覆盖。默认为 False

变量:

recurrent_mode – 返回模块的循环模式。

set_recurrent_mode()[源代码]

控制模块是否应以循环模式执行。

make_tensordict_primer()[源代码]

为环境创建 TensorDictPrimer 转换,使其能够感知 RNN 的循环状态。

注意

此模块依赖于输入 TensorDict 中存在的特定 recurrent_state 键。要生成一个 TensorDictPrimer 转换,该转换将在回放执行期间自动将隐藏状态添加到环境 TensorDict 中,请使用方法 make_tensordict_primer()。如果此类是大型模块中的子模块,则可以对父模块调用方法 get_primers_from_module() 来自动生成包括此模块在内的所有子模块所需的 primer 转换。

示例

>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm_module = LSTMModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs_h", "rs_c"],
...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs_c: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                rs_h: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
forward(tensordict: TensorDictBase = None)[源代码]

定义每次调用时执行的计算。

所有子类都应重写此方法。

注意

虽然前向传播的实现需要在该函数内定义,但之后应调用“Module”实例而不是此函数,因为前者负责运行已注册的钩子,而后者则会默默地忽略它们。

make_cudnn_based() LSTMModule[源代码]

将 LSTM 层转换为其基于 CuDNN 的版本。

返回:

self

make_python_based() LSTMModule[源代码]

将 LSTM 层转换为其基于 Python 的版本。

返回:

self

make_tensordict_primer()[源代码]

为环境创建一个 tensordict primer。

一个 TensorDictPrimer 对象将确保策略在回放执行期间能够感知补充输入和输出(循环状态)。这样,数据就可以在进程之间共享并得到妥善处理。

当使用批处理环境(如 ParallelEnv)时,该转换可以在单个环境实例级别(即,一批转换后的环境,其内部设置了 tensordict primer)或在批处理环境实例级别(即,转换后的常规环境批次)使用。

不在环境中包含 TensorDictPrimer 可能会导致行为不当,例如在并行环境中,一个步进涉及将新的循环状态从 "next" 复制到根 tensordict,而 ~torchrl.EnvBase.step_mdp 方法将无法执行此操作,因为循环状态未在环境规范中注册。

有关生成给定模块所有 primer 的方法,请参阅 torchrl.modules.utils.get_primers_from_module()

示例

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm_module = LSTMModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs_h", "rs_c"],
...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(lstm_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=10
... )
>>> for data in data_collector:
...     print(data)
...     break

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源