概述¶
TensorDict 使组织数据和编写可重用、通用的 PyTorch 代码变得容易。它最初是为 TorchRL 开发的,现在已作为独立库发布。
TensorDict 主要是一个字典,同时也是一个类似张量的类:它支持多种主要与形状和存储相关的张量操作。它旨在能够有效地序列化或从节点到节点、进程到进程传输。最后,它还附带了自己的 nn
模块,该模块与 torch.func
兼容,旨在简化模型集成和参数操作。
在本页面,我们将介绍 TensorDict
的动机,并提供一些它功能的示例。
动机¶
TensorDict 允许您编写可在不同范式之间重用的通用代码模块。例如,以下循环可以用于大多数 SL、SSL、UL 和 RL 任务。
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
通过其 nn
模块,该包提供了许多工具,可以轻松地在代码库中使用 TensorDict
。
在多进程或分布式环境中,TensorDict
允许您无缝地将数据分发给每个工作进程。
>>> # creates batches of 10 datapoints
>>> splits = torch.arange(tensordict.shape[0]).split(10)
>>> for worker in range(workers):
... idx = splits[worker]
... pipe[worker].send(tensordict[idx])
TensorDict 提供的一些操作也可以通过 tree_map 完成,但复杂度更高。
>>> td = TensorDict(
... {"a": torch.randn(3, 11), "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": td["a"], "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": regular_dicts["a"][i], "b": regular_dicts["b"][i]}
... for i in range(3)]
嵌套情况更具说服力。
>>> td = TensorDict(
... {"a": {"c": torch.randn(3, 11)}, "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": {"c": td["a", "c"]}, "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]}
... for i in range(3)
在对 pytree 进行原始操作时,在应用 unbind 操作后将输出字典分解为三个结构相似的字典,会变得非常麻烦。使用 tensordict,我们为想要 unbind 或拆分嵌套结构的用户的提供了简单的 API,而不是计算嵌套的拆分/unbind 嵌套结构。
特性¶
一个 TensorDict
是一个类似字典的张量容器。要实例化一个 TensorDict
,您可以指定键值对以及批次大小(可以通过 TensorDict() 创建一个空的 tensordict)。TensorDict
中任何值的领先维度必须与批次大小兼容。
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
... batch_size=[2, 3],
... )
设置或检索值的语法与常规字典非常相似。
>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)
还可以沿其 batch_size 索引一个 tensordict,这使得只需几个字符即可获得数据的同等切片(请注意,使用 ellipsis 和 tree_map 索引第 n 个领先维度需要更多的编码)。
>>> sub_tensordict = tensordict[..., :2]
还可以使用 inplace=True
的 set 方法或 set_()
方法进行原地更新。前者是后者的容错版本:如果找不到匹配的键,它将写入一个新键。
现在可以集体操作 TensorDict 的内容。例如,要将所有内容放置到特定设备,只需执行以下操作:
>>> tensordict = tensordict.to("cuda:0")
然后,您可以断言 tensordict 的设备是 “cuda:0”。
>>> assert tensordict.device == torch.device("cuda:0")
要重塑批次维度,您可以执行以下操作:
>>> tensordict = tensordict.reshape(6)
该类还支持许多其他操作,包括 squeeze()
、unsqueeze()
、view()
、permute()
、unbind()
、stack()
、cat()
等等。
如果缺少某个操作,apply()
方法通常会提供所需解决方案。
规避形状操作¶
在某些情况下,可能希望在不强制形状操作期间批次大小一致性的情况下将张量存储在 TensorDict 中。
这可以通过将张量包装在 UnbatchedTensor
实例中来实现。
UnbatchedTensor
在 TensorDict 的形状操作期间会忽略其形状,从而允许灵活地存储和操作任意形状的张量。
>>> from tensordict import UnbatchedTensor
>>> tensordict = TensorDict({"zeros": UnbatchedTensor(torch.zeros(10))}, batch_size=[2, 3])
>>> reshaped_td = tensordict.reshape(6)
>>> reshaped_td["zeros"] is tensordict["zeros"]
True
非张量数据¶
Tensordict 是一个强大的处理张量数据的库,但也支持非张量数据。本指南将向您展示如何使用 tensordict 处理非张量数据。
使用非张量数据创建 TensorDict¶
您可以使用 NonTensorData
类创建包含非张量数据的 TensorDict。
>>> from tensordict import TensorDict, NonTensorData
>>> import torch
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b=torch.zeros(()),
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
如您所见,NonTensorData
对象就像普通张量一样存储在 TensorDict 中。
MetaData
类可用于承载不可索引的数据,或不需要遵循 tensordict 批次大小的数据。
访问非张量数据¶
您可以使用键或 get 方法访问非张量数据。常规的 getattr 调用将返回 NonTensorData
对象的内容,而 get()
将返回 NonTensorData
对象本身。
>>> print(td["a"]) # prints: a string!
>>> print(td.get("a")) # prints: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None)
批处理的非张量数据¶
如果您有一个非张量数据的批次,您可以将其存储在具有指定批次大小的 TensorDict 中。
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
在这种情况下,我们假设 tensordict 的所有元素都具有相同的非张量数据。
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
要为已排序 tensordict 中的每个元素分配不同的非张量数据对象,可以使用非张量数据堆栈。
堆叠的非张量数据¶
如果您有一个要存储在 TensorDict
中的非张量数据列表,您可以使用 NonTensorStack
类。
>>> td = TensorDict(
... a=NonTensorStack("a string!", "another string!", "a third string!"),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorStack(
['a string!', 'another string!', 'a third string!'...,
batch_size=torch.Size([3]),
device=None),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
您可以访问第一个元素,您将获得第一个字符串。
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
相比之下,将 NonTensorData
与列表一起使用不会得到相同的结果,因为无法确定如何普遍处理恰好是列表的非张量数据。
>>> td = TensorDict(
... a=NonTensorData(["a string!", "another string!", "a third string!"]),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=['a string!', 'another string!', 'a third string!'], batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
堆叠带有非张量数据的 TensorDict¶
为了堆叠非张量数据,stack()
将创建一个 NonTensorStack
。相比之下,在使用 MetaData
实例时,如果内容匹配,堆叠操作将生成单个 MetaData 实例。
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b = torch.zeros(()),
... )
>>> print(torch.stack([td, td]))
TensorDict(
fields={
a: NonTensorStack(
['a string!', 'a string!'],
batch_size=torch.Size([2]),
device=None),
b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
>>> td = TensorDict(
... a=MetaData("a string!"),
... b = torch.zeros(()),
... )
>>> print(torch.stack([td, td]))
TensorDict(
fields={
a: MetaData(data=a string!, batch_size=torch.Size([2]), device=None),
b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
命名维度¶
TensorDict 和相关类也支持维度名称。名称可以在构造时给出,也可以稍后进行细化。其语义与 torch.Tensor 的维度名称功能类似。
>>> tensordict = TensorDict({}, batch_size=[3, 4], names=["a", None])
>>> tensordict.refine_names(..., "b")
>>> tensordict.names = ["z", "y"]
>>> tensordict.rename("m", "n")
>>> tensordict.rename(m="h")
嵌套 TensorDicts¶
TensorDict
中的值本身可以是 TensorDicts(下面示例中的嵌套字典将被转换为嵌套 TensorDicts)。
>>> tensordict = TensorDict(
... {
... "inputs": {
... "image": torch.rand(100, 28, 28),
... "mask": torch.randint(2, (100, 28, 28), dtype=torch.uint8)
... },
... "outputs": {"logits": torch.randn(100, 10)},
... },
... batch_size=[100],
... )
可以通过字符串元组访问或设置嵌套键。
>>> image = tensordict["inputs", "image"]
>>> logits = tensordict.get(("outputs", "logits")) # alternative way to access
>>> tensordict["outputs", "probabilities"] = torch.sigmoid(logits)
延迟评估¶
TensorDict
上的一些操作会推迟执行,直到访问项目。例如,堆叠、挤压、扩展、置换批次维度和创建视图不会立即在 TensorDict
的所有内容上执行。相反,它们会在访问 TensorDict
中的值时惰性执行。如果 TensorDict
包含许多值,这可以节省大量不必要的计算。
>>> tensordicts = [TensorDict({
... "a": torch.rand(10),
... "b": torch.rand(10, 1000, 1000)}, [10])
... for _ in range(3)]
>>> stacked = torch.stack(tensordicts, 0) # no stacking happens here
>>> stacked_a = stacked["a"] # we stack the a values, b values are not stacked
它还有一个优点是我们可以操纵堆栈中的原始 tensordicts。
>>> stacked["a"] = torch.zeros_like(stacked["a"])
>>> assert (tensordicts[0]["a"] == 0).all()
需要注意的是,get 方法现在变成了一个昂贵的操作,如果重复多次,可能会导致一些开销。可以通过在 stack 执行后简单调用 tensordict.contiguous() 来避免此问题。为了进一步缓解此问题,TensorDict 提供了自己的元数据类(MetaTensor),该类跟踪字典每个条目的类型、形状、dtype 和设备,而无需执行昂贵的操作。
延迟预分配¶
假设我们有一个函数 foo() -> TensorDict,并且我们执行以下操作:
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in range(N):
... tensordict[i] = foo()
当 i == 0
时,空的 TensorDict
将自动填充具有批次大小 N 的空张量。在循环的后续迭代中,所有更新都将原地写入。
TensorDictModule¶
为了便于将 TensorDict
集成到代码库中,我们提供了 tensordict.nn 包,允许用户将 TensorDict
实例传递给 Module
对象(或任何可调用对象)。
TensorDictModule
包装了 Module
,并接受一个 TensorDict
作为输入。您可以指定底层模块应从何处获取输入,以及应将输出写入何处。这是我们能够编写可重用的、通用的高级代码(如动机部分中的训练循环)的关键原因。
>>> from tensordict.nn import TensorDictModule
>>> class Net(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.LazyLinear(1)
...
... def forward(self, x):
... logits = self.linear(x)
... return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
... Net(),
... in_keys=["input"],
... out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> tensordict = module(tensordict)
>>> # outputs can now be retrieved from the tensordict
>>> logits = tensordict["outputs", "logits"]
>>> probabilities = tensordict.get(("outputs", "probabilities"))
为了方便采用此类,您还可以将张量作为 kwargs 传递。
>>> tensordict = module(input=torch.randn(32, 100))
这将返回一个与上一个代码框中的 TensorDict
相同的 TensorDict
。有关此功能的更多背景信息,请参阅 导出教程。
许多 PyTorch 用户面临的一个主要痛点是 nn.Sequential 无法处理具有多个输入的模块。使用基于键的图可以轻松解决此问题,因为序列中的每个节点都知道需要读取哪些数据以及写入何处。
为此,我们提供了 TensorDictSequential
类,它将数据传递给一系列 TensorDictModules
。序列中的每个模块都从原始 TensorDict
获取输入,并将输出写入其中,这意味着序列中的模块可以忽略其前驱的输出,或根据需要从 tensordict 获取额外的输入。这是一个例子:
>>> class Net(nn.Module):
... def __init__(self, input_size=100, hidden_size=50, output_size=10):
... super().__init__()
... self.fc1 = nn.Linear(input_size, hidden_size)
... self.fc2 = nn.Linear(hidden_size, output_size)
...
... def forward(self, x):
... x = torch.relu(self.fc1(x))
... return self.fc2(x)
...
... class Masker(nn.Module):
... def forward(self, x, mask):
... return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
... Masker(),
... in_keys=[("intermediate", "x"), ("input", "mask")],
... out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> tensordict = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> tensordict = module(tensordict)
>>> intermediate_x = tensordict["intermediate", "x"]
>>> probabilities = tensordict["output", "probabilities"]
在此示例中,第二个模块将第一个模块的输出与存储在 TensorDict
的 (“inputs”, “mask”) 下的掩码结合起来。
TensorDictSequential
提供了许多其他功能:可以通过查询 in_keys 和 out_keys 属性来访问输入和输出键的列表。还可以通过使用所需的输入和输出键集查询 select_subsequence()
来请求子图。这将返回另一个 TensorDictSequential
,其中仅包含满足这些要求所必需的模块。 TensorDictModule
也兼容 vmap()
和其他 torch.func
功能。