
TensorDictModule

Author: Nicolas Dufour, Vincent Moens

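In this tutorial, you will learn how to use TensorDictModule and TensorDictSequential to create generic and reusable modules that accept a TensorDict as input.

For convenient use of the TensorDict class with Module, tensordict provides an interface between the two, named TensorDictModule.

The TensorDictModule class is a Module that takes a TensorDict as input when called. It reads a sequence of input keys, passes them as inputs to the wrapped module or function, and, once execution completes, writes the outputs into the same tensordict.

It is up to the user to define which keys are read as inputs and which are written as outputs.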

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

Simple example: coding a residual block

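The simplest usage of TensorDictModule is shown below. If using this class seems to introduce unnecessary complexity at first, we will see later that this API lets users programmatically chain modules together, cache values between modules, or build modules programmatically. One of the simplest examples of this is the residual connection in ResNet-like architectures, where the input of a module is cached and added to the output of a small multi-layer perceptron (MLP).

To start, let's consider how we would split an MLP into chunks and code it using tensordict.nn. The first layer of the stack would be a Linear layer, taking one entry as input (let's call it x) and writing another entry as output (here, "linear0").

To feed our module, we use a TensorDict instance with a single entry, "x":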

tensordict = TensorDict(
    x=torch.randn(5, 3),
    batch_size=[5],
)

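Now we build our simple module using tensordict.nn.TensorDictModule. By default, this class writes into the input tensordict in-place (meaning that the new entries are written into the same tensordict as the input, not that existing entries are overwritten in-place!), so we don't need to explicitly handle the output: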

linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)

assert "linear0" in tensordict

If the module outputs multiple tensors (or tensordicts!), their entries must be passed to TensorDictModule as out_keys in the correct order.

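Supporting callables

When designing a model, you often want to incorporate an arbitrary non-parametric function into the network. For instance, you may wish to permute the dimensions of an image before it is passed to a convolutional network or a vision transformer, or divide its values by 255. There are several ways to do this: you could use a forward_hook, for example, or design a new Module that performs this operation.

TensorDictModule works with any callable, not just modules, which makes it easy to incorporate arbitrary functions into a model. For instance, here is how to integrate the relu activation function without using the ReLU module: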

relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])

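Stacking modules

Our MLP isn't made of a single layer, so we now need to add another one. This layer will be an activation function, for instance ReLU. We can stack this module and the previous one using TensorDictSequential.

Note

This is where the true power of tensordict.nn lies: unlike Sequential, TensorDictSequential keeps in memory all previous inputs and outputs (with the option to filter them out afterwards), making it easy to build complex network structures on the fly and programmatically.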

block0 = TensorDictSequential(linear0, relu0)

block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict

We can repeat this logic to build a full MLP:

linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)

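Multiple input keys

The last step of the residual network is to add the input to the output of the last linear layer. There is no need to write a special Module subclass for this! TensorDictModule can wrap simple functions too: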

residual = TensorDictModule(
    lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)

We can now put together block0, block1 and residual to build a complete residual block:

block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict

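A legitimate concern is the accumulation of entries in the tensordict: in some cases (for example, when gradients are required) intermediate values are kept in memory anyway, but this is not always the case, and it can be useful to let the garbage collector know that some entries can be discarded. tensordict.nn.TensorDictModuleBase and its subclasses (including tensordict.nn.TensorDictModule and tensordict.nn.TensorDictSequential) can optionally filter their output keys after execution. To do this, simply call the tensordict.nn.TensorDictModuleBase.select_out_keys method. This updates the module in-place, and all unwanted entries are discarded: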

block.select_out_keys("y")

tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict

assert "linear1" not in tensordict

The input keys, however, are preserved:

assert "x" in tensordict

As a side note, selected_out_keys can also be passed to the tensordict.nn.TensorDictSequential constructor to avoid calling this method separately.

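Using TensorDictModule without tensordict

The opportunity that tensordict.nn.TensorDictSequential offers to build complex architectures on the fly does not mean that you must switch to tensordict to represent your data. Thanks to dispatch, modules from tensordict.nn accept arguments and keyword arguments that match the entry names: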

x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)

Under the hood, dispatch rebuilds a tensordict, runs the module and then deconstructs it. This can introduce some overhead, but, as we will see next, there is a way to get rid of it.

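Runtime

tensordict.nn.TensorDictModule and tensordict.nn.TensorDictSequential do incur some overhead when executed, since they have to read from and write to a tensordict. However, this overhead can be greatly reduced with compile(). To see this, let us compare three versions of this code (a plain nn.Module, a TensorDictModule and a TensorDictSequential), with and without compile: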

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(3, 128)
        self.relu0 = nn.ReLU()
        self.linear1 = nn.Linear(128, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 3)

    def forward(self, x):
        y = self.linear0(x)
        y = self.relu0(y)
        y = self.linear1(y)
        y = self.relu1(y)
        return self.linear2(y) + x


print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block

from torch.utils.benchmark import Timer

print(
    f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
    f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
    f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_notd_c(x)
print(
    f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tdm_c(x=x)
print(
    f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tds_c(x=x)
print(
    f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
Without compile
Regular:  215.3987 us
TDM:  280.3646 us
Sequential:  503.1584 us
Compiled versions
Compiled regular:  374.3750 us
Compiled TDM:  408.7855 us
Compiled sequential:  377.9500 us

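As you can see, once compiled, the TensorDictSequential version runs about as fast as the compiled plain module (377.95 us versus 374.38 us), so the overhead it introduces is essentially eliminated. Note that in this run everything executed on CPU, where mode="reduce-overhead" cannot use CUDA graphs, so the compiled versions were slower in absolute terms than the eager regular module; actual timings depend on the hardware.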

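Caveats when using TensorDictModule

  • Don't wrap modules from tensordict.nn in an nn.Sequential: this breaks the input/output key structure. Always rely on tensordict.nn.TensorDictSequential instead.

  • Don't assign the output tensordict to a new variable, since the output tensordict is simply the input modified in-place. Assigning a new variable name is not strictly forbidden, but it suggests you may expect both to disappear when one is deleted, when in fact the garbage collector will still see the tensors in the workspace and no memory will be freed: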

    >>> tensordict = module(tensordict)  # ok!
    >>> tensordict_out = module(tensordict)  # don't!
    

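Handling distributions: ProbabilisticTensorDictModule

ProbabilisticTensorDictModule is a non-parametric module representing a probability distribution. The distribution parameters are read from the input tensordict, and the output is written to an output tensordict. Sampling given these parameters follows a rule specified by the default_interaction_type input argument and the interaction_type() global function, whose value is set with the set_interaction_type context manager. If the two conflict, the context manager takes precedence.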

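It can be chained, using ProbabilisticTensorDictSequential, with a TensorDictModule that returns a tensordict updated with the distribution parameters. ProbabilisticTensorDictSequential is a special case of TensorDictSequential whose last layer is a ProbabilisticTensorDictModule instance.

ProbabilisticTensorDictModule is responsible for constructing the distribution (through the get_dist() method) and/or sampling from it (through a regular forward call to the module). The same get_dist() method is exposed in ProbabilisticTensorDictSequential.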

The parameters can be found in the output tensordict, along with the log-probability if requested (here through return_log_prob=True).

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist

td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now has keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
TensorDict after going through module now has keys action, loc and scale: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        action_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

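Conclusion

We have seen how tensordict.nn can be used to build complex neural network architectures dynamically, on the fly. This opens the possibility of building pipelines that are agnostic to the model signature, i.e., writing generic code that flexibly uses networks with an arbitrary number of inputs or outputs.

We have also seen how dispatch lets tensordict.nn build such networks and use them without relying on TensorDict directly. Thanks to compile(), the overhead introduced by tensordict.nn.TensorDictSequential can be essentially eliminated, leaving users with a neat, tensordict-free version of their module.

In the next tutorial, we will see how torch.export can be used to isolate and export a module.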

Total running time of the script: (0 minutes 13.884 seconds)

Gallery generated by Sphinx-Gallery
