注意
转到末尾 下载完整的示例代码。
TensorDictModule¶
作者: Nicolas Dufour, Vincent Moens
在本教程中,您将学习如何使用 TensorDictModule
和 TensorDictSequential
来创建可以接受 TensorDict
作为输入的通用且可重用的模块。
为了方便将 TensorDict
类与 Module
结合使用,tensordict
提供了它们之间的接口,名为 TensorDictModule
。
TensorDictModule
类是一个 Module
,在被调用时接受一个 TensorDict
作为输入。它将读取一系列输入键,将它们作为输入传递给包装的模块或函数,并在完成执行后将输出写入同一个 tensordict 中。
由用户来定义要读取的输入键和输出键。
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
简单示例:编码一个循环层¶
下面例举了 TensorDictModule
最简单的用法。如果一开始您觉得使用此类会引入不必要的复杂性,我们稍后将看到,此 API 使用户能够以编程方式将模块串联起来、在模块之间缓存值或以编程方式构建模块。其中一个最简单的例子是像 ResNet 这样的架构中的循环模块,其中模块的输入被缓存并添加到小型多层感知机 (MLP) 的输出中。
首先,让我们考虑如何将一个 MLP 分块,并使用 tensordict.nn
来编码它。堆栈的第一层可能是一个 Linear
层,它接受一个条目(我们称之为 x)作为输入,并输出另一个条目(我们称之为 y)。
为了馈送给我们的模块,我们有一个 TensorDict
实例,其中包含一个条目 "x"
。
tensordict = TensorDict(
x=torch.randn(5, 3),
batch_size=[5],
)
现在,我们使用 tensordict.nn.TensorDictModule
构建我们的简单模块。默认情况下,此类会就地写入输入 tensordict(这意味着条目会写入与输入相同的 tensordict,而不是就地覆盖条目!),因此我们无需显式指示输出是什么。
linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)
assert "linear0" in tensordict
如果模块输出多个张量(或 tensordicts!),则必须按正确的顺序将它们的条目传递给 TensorDictModule
。
对可调用对象的支持¶
在设计模型时,您经常需要将任意非参数函数纳入网络。例如,您可能希望在图像传递到卷积网络或视觉 transformer 时对其维度进行置换,或者将值除以 255。有几种方法可以做到这一点:例如,您可以使用 forward_hook,或者设计一个新的 Module
来执行此操作。
TensorDictModule
可与任何可调用对象(不仅仅是模块)配合使用,这使其易于将任意函数集成到模块中。例如,让我们看看如何在不使用 ReLU
模块的情况下集成 relu
激活函数。
relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])
堆叠模块¶
我们的 MLP 不止一个层,所以我们需要再添加一层。这一层将是激活函数,例如 ReLU
。我们可以使用 TensorDictSequential
将此模块和之前的模块堆叠起来。
注意
tensordict.nn
的真正强大之处在于:与 Sequential
不同,TensorDictSequential
将在内存中保留所有先前的输入和输出(并且可以事后过滤掉它们),这使得轻松地即时以编程方式构建复杂的网络结构成为可能。
block0 = TensorDictSequential(linear0, relu0)
block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict
我们可以重复此逻辑来获得完整的 MLP。
linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)
多个输入键¶
残差网络的最后一步是将输入加到最后一个线性层的输出上。为此无需编写特殊的 Module
子类! TensorDictModule
也可以包装简单函数。
residual = TensorDictModule(
lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)
现在,我们可以将 block0
、block1
和 residual
组合起来,构建一个完整的残差块。
block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict
一个真正的顾虑可能是输入 tensordict 中条目的累积:在某些情况下(例如,当需要梯度时),中间值会被缓存,但这并不总是如此,而且让垃圾回收器知道某些条目可以被丢弃可能会很有用。 tensordict.nn.TensorDictModuleBase
及其子类(包括 tensordict.nn.TensorDictModule
和 tensordict.nn.TensorDictSequential
)可以选择在执行后过滤掉它们的输出键。为此,只需调用 tensordict.nn.TensorDictModuleBase.select_out_keys
方法。这将就地更新模块,并且所有不需要的条目都将被丢弃。
block.select_out_keys("y")
tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict
assert "linear1" not in tensordict
但是,输入键会被保留。
assert "x" in tensordict
顺便说一句,selected_out_keys
也可以传递给 tensordict.nn.TensorDictSequential
,以避免单独调用此方法。
未使用 tensordict 的 TensorDictModule¶
tensordict.nn.TensorDictSequential
提供的即时构建复杂架构的机会,并不意味着一个人必须切换到 tensordict 来表示数据。得益于 dispatch
,tensordict.nn 中的模块支持与条目名称匹配的参数和关键字参数。
x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)
在底层,dispatch
会重建一个 tensordict,运行模块,然后将其解构。这可能会产生一些开销,但正如我们接下来将看到的,有一个解决方案可以消除这一点。
运行时¶
tensordict.nn.TensorDictModule
和 tensordict.nn.TensorDictSequential
在执行时会产生一些开销,因为它们需要从 tensordict 读取和写入。但是,我们可以通过使用 compile()
来大大减少这种开销。为此,让我们比较一下此代码在启用和禁用编译的三种版本。
class ResidualBlock(nn.Module):
def __init__(self):
super().__init__()
self.linear0 = nn.Linear(3, 128)
self.relu0 = nn.ReLU()
self.linear1 = nn.Linear(128, 128)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(128, 3)
def forward(self, x):
y = self.linear0(x)
y = self.relu0(y)
y = self.linear1(y)
y = self.relu1(y)
return self.linear2(y) + x
print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block
from torch.utils.benchmark import Timer
print(
f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5): # warmup
block_notd_c(x)
print(
f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5): # warmup
block_tdm_c(x=x)
print(
f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5): # warmup
block_tds_c(x=x)
print(
f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
Without compile
Regular: 219.9135 us
TDM: 275.8492 us
Sequential: 486.7035 us
Compiled versions
Compiled regular: 308.4260 us
Compiled TDM: 376.0630 us
Compiled sequential: 349.8425 us
正如您所见,TensorDictSequential
引入的开销已完全消除。
TensorDictModule 的注意事项¶
不要在
tensordict.nn
的模块周围使用Sequence
。这会破坏输入/输出键结构。始终尝试改用nn:TensorDictSequential
。不要将输出 tensordict 分配给新变量,因为输出 tensordict 只是原地修改的输入。分配新变量名称并非严格禁止,但这意味着您可能希望在其中一个被删除时两者都消失,而实际上垃圾回收器仍然可以看到工作空间中的张量,并且不会释放内存。
>>> tensordict = module(tensordict) # ok! >>> tensordict_out = module(tensordict) # don't!
处理分布: ProbabilisticTensorDictModule
¶
ProbabilisticTensorDictModule
是一个表示概率分布的非参数模块。分布参数从 tensordict 输入读取,输出写入输出 tensordict。输出根据规则采样,该规则由输入 default_interaction_type
参数和 interaction_type()
全局函数指定。如果它们冲突,则上下文管理器优先。
它可以与返回更新了分布参数的 tensordict 的 TensorDictModule
一起使用,通过 ProbabilisticTensorDictSequential
进行连接。这是 TensorDictSequential
的一种特殊情况,其最后一层是 ProbabilisticTensorDictModule
实例。
ProbabilisticTensorDictModule
负责构造分布(通过 get_dist()
方法)和/或从该分布采样(通过对模块的常规 forward 调用)。相同的 get_dist()
方法在 ProbabilisticTensorDictSequential
中公开。
可以在输出 tensordict 中找到参数,如果需要,也可以找到对数概率。
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
net,
extractor,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=dist.Normal,
return_log_prob=True,
),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
fields={
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
action_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
结论¶
我们已经看到 tensordict.nn 如何用于动态地即时构建复杂的神经网络架构。这为构建对模型签名不敏感的管道提供了可能性,即编写通用代码,以灵活的方式使用具有任意数量输入或输出的网络。
我们还看到了 dispatch
如何能够使用 tensordict.nn 来构建此类网络,并在不直接依赖 TensorDict
的情况下使用它们。得益于 compile()
,tensordict.nn.TensorDictSequential
引入的开销可以完全消除,为用户留下一个整洁、无 tensordict 的模块版本。
在下一教程中,我们将看到 torch.export
如何用于隔离和导出模块。
脚本总运行时间: (0 分钟 13.134 秒)