快捷方式

TensorDictModule

作者Nicolas Dufour, Vincent Moens

在本教程中,您将学习如何使用 TensorDictModuleTensorDictSequential 来创建可以接受 TensorDict 作为输入的通用且可重用的模块。

为了方便将 TensorDict 类与 Module 结合使用,tensordict 提供了它们之间的接口,名为 TensorDictModule

TensorDictModule 类是一个 Module,在被调用时接受一个 TensorDict 作为输入。它将读取一系列输入键,将它们作为输入传递给包装的模块或函数,并在完成执行后将输出写入同一个 tensordict 中。

由用户来定义要读取的输入键和输出键。

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

简单示例:编码一个循环层

下面例举了 TensorDictModule 最简单的用法。如果一开始您觉得使用此类会引入不必要的复杂性,我们稍后将看到,此 API 使用户能够以编程方式将模块串联起来、在模块之间缓存值或以编程方式构建模块。其中一个最简单的例子是像 ResNet 这样的架构中的循环模块,其中模块的输入被缓存并添加到小型多层感知机 (MLP) 的输出中。

首先,让我们考虑如何将一个 MLP 分块,并使用 tensordict.nn 来编码它。堆栈的第一层可能是一个 Linear 层,它接受一个条目(我们称之为 x)作为输入,并输出另一个条目(我们称之为 y)。

为了馈送给我们的模块,我们有一个 TensorDict 实例,其中包含一个条目 "x"

tensordict = TensorDict(
    x=torch.randn(5, 3),
    batch_size=[5],
)

现在,我们使用 tensordict.nn.TensorDictModule 构建我们的简单模块。默认情况下,此类会就地写入输入 tensordict(这意味着条目会写入与输入相同的 tensordict,而不是就地覆盖条目!),因此我们无需显式指示输出是什么。

linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)

assert "linear0" in tensordict

如果模块输出多个张量(或 tensordicts!),则必须按正确的顺序将它们的条目传递给 TensorDictModule

对可调用对象的支持

在设计模型时,您经常需要将任意非参数函数纳入网络。例如,您可能希望在图像传递到卷积网络或视觉 transformer 时对其维度进行置换,或者将值除以 255。有几种方法可以做到这一点:例如,您可以使用 forward_hook,或者设计一个新的 Module 来执行此操作。

TensorDictModule 可与任何可调用对象(不仅仅是模块)配合使用,这使其易于将任意函数集成到模块中。例如,让我们看看如何在不使用 ReLU 模块的情况下集成 relu 激活函数。

relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])

堆叠模块

我们的 MLP 不止一个层,所以我们需要再添加一层。这一层将是激活函数,例如 ReLU。我们可以使用 TensorDictSequential 将此模块和之前的模块堆叠起来。

注意

tensordict.nn 的真正强大之处在于:与 Sequential 不同,TensorDictSequential 将在内存中保留所有先前的输入和输出(并且可以事后过滤掉它们),这使得轻松地即时以编程方式构建复杂的网络结构成为可能。

block0 = TensorDictSequential(linear0, relu0)

block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict

我们可以重复此逻辑来获得完整的 MLP。

linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)

多个输入键

残差网络的最后一步是将输入加到最后一个线性层的输出上。为此无需编写特殊的 Module 子类! TensorDictModule 也可以包装简单函数。

residual = TensorDictModule(
    lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)

现在,我们可以将 block0block1residual 组合起来,构建一个完整的残差块。

block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict

一个真正的顾虑可能是输入 tensordict 中条目的累积:在某些情况下(例如,当需要梯度时),中间值会被缓存,但这并不总是如此,而且让垃圾回收器知道某些条目可以被丢弃可能会很有用。 tensordict.nn.TensorDictModuleBase 及其子类(包括 tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential)可以选择在执行后过滤掉它们的输出键。为此,只需调用 tensordict.nn.TensorDictModuleBase.select_out_keys 方法。这将就地更新模块,并且所有不需要的条目都将被丢弃。

block.select_out_keys("y")

tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict

assert "linear1" not in tensordict

但是,输入键会被保留。

assert "x" in tensordict

顺便说一句,selected_out_keys 也可以传递给 tensordict.nn.TensorDictSequential,以避免单独调用此方法。

未使用 tensordict 的 TensorDictModule

tensordict.nn.TensorDictSequential 提供的即时构建复杂架构的机会,并不意味着一个人必须切换到 tensordict 来表示数据。得益于 dispatchtensordict.nn 中的模块支持与条目名称匹配的参数和关键字参数。

x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)

在底层,dispatch 会重建一个 tensordict,运行模块,然后将其解构。这可能会产生一些开销,但正如我们接下来将看到的,有一个解决方案可以消除这一点。

运行时

tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential 在执行时会产生一些开销,因为它们需要从 tensordict 读取和写入。但是,我们可以通过使用 compile() 来大大减少这种开销。为此,让我们比较一下此代码在启用和禁用编译的三种版本。

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(3, 128)
        self.relu0 = nn.ReLU()
        self.linear1 = nn.Linear(128, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 3)

    def forward(self, x):
        y = self.linear0(x)
        y = self.relu0(y)
        y = self.linear1(y)
        y = self.relu1(y)
        return self.linear2(y) + x


print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block

from torch.utils.benchmark import Timer

print(
    f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
    f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
    f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_notd_c(x)
print(
    f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tdm_c(x=x)
print(
    f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tds_c(x=x)
print(
    f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
Without compile
Regular:  219.9135 us
TDM:  275.8492 us
Sequential:  486.7035 us
Compiled versions
Compiled regular:  308.4260 us
Compiled TDM:  376.0630 us
Compiled sequential:  349.8425 us

正如您所见,TensorDictSequential 引入的开销已完全消除。

TensorDictModule 的注意事项

  • 不要在 tensordict.nn 的模块周围使用 Sequence。这会破坏输入/输出键结构。始终尝试改用 nn:TensorDictSequential

  • 不要将输出 tensordict 分配给新变量,因为输出 tensordict 只是原地修改的输入。分配新变量名称并非严格禁止,但这意味着您可能希望在其中一个被删除时两者都消失,而实际上垃圾回收器仍然可以看到工作空间中的张量,并且不会释放内存。

    >>> tensordict = module(tensordict)  # ok!
    >>> tensordict_out = module(tensordict)  # don't!
    

处理分布: ProbabilisticTensorDictModule

ProbabilisticTensorDictModule 是一个表示概率分布的非参数模块。分布参数从 tensordict 输入读取,输出写入输出 tensordict。输出根据规则采样,该规则由输入 default_interaction_type 参数和 interaction_type() 全局函数指定。如果它们冲突,则上下文管理器优先。

它可以与返回更新了分布参数的 tensordict 的 TensorDictModule 一起使用,通过 ProbabilisticTensorDictSequential 进行连接。这是 TensorDictSequential 的一种特殊情况,其最后一层是 ProbabilisticTensorDictModule 实例。

ProbabilisticTensorDictModule 负责构造分布(通过 get_dist() 方法)和/或从该分布采样(通过对模块的常规 forward 调用)。相同的 get_dist() 方法在 ProbabilisticTensorDictSequential 中公开。

可以在输出 tensordict 中找到参数,如果需要,也可以找到对数概率。

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist

td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        action_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

结论

我们已经看到 tensordict.nn 如何用于动态地即时构建复杂的神经网络架构。这为构建对模型签名不敏感的管道提供了可能性,即编写通用代码,以灵活的方式使用具有任意数量输入或输出的网络。

我们还看到了 dispatch 如何能够使用 tensordict.nn 来构建此类网络,并在不直接依赖 TensorDict 的情况下使用它们。得益于 compile()tensordict.nn.TensorDictSequential 引入的开销可以完全消除,为用户留下一个整洁、无 tensordict 的模块版本。

在下一教程中,我们将看到 torch.export 如何用于隔离和导出模块。

脚本总运行时间: (0 分钟 13.134 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源