快捷方式

WrapModule

class tensordict.nn.WrapModule(*args, **kwargs)

一个包装任何处理 TensorDict 实例的可调用对象的包装器。

当构建 TensorDictSequential 堆栈以及转换需要整个 TensorDict 实例可见时,此包装器非常有用。

参数:

func (Callable[[TensorDictBase], TensorDictBase]) – 一个可调用函数,它接收一个 TensorDictBase 实例并返回一个转换后的 TensorDictBase 实例。

关键字参数:
  • inplace (bool, optional) – 如果为 True,则输入 TensorDict 将被就地修改。否则,将返回一个新的 TensorDict(如果函数不就地修改并返回它)。默认为 False

  • in_keys (list of NestedKey, optional) – 如果提供,则指示模块读取哪些条目。这不会被检查,仅用于向 TensorDictSequential 提供有关被包装模块输入键的信息。默认为 []

  • out_keys (list of NestedKey, optional) – 如果提供,则指示模块写入哪些条目。这不会被检查,仅用于向 TensorDictSequential 提供有关被包装模块输出键的信息。默认为 []

示例

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
>>> seq = Seq(
...     Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]),
...     WrapModule(lambda td: td.reshape(-1)),
... )
>>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4])
>>> td = Seq(td)
>>> assert td.shape == (12,)
>>> assert (td["y"] == 2).all()
>>> assert td["y"].shape == (12, 5)
forward(data: TensorDictBase) TensorDictBase

定义每次调用时执行的计算。

所有子类都应重写此方法。

注意

虽然前向传递的实现需要在此函数中定义,但之后应调用 Module 实例而不是此函数,因为前者负责运行已注册的钩子,而后者会静默地忽略它们。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源