快捷方式

概述

TensorDict 使组织数据和编写可重用、通用的 PyTorch 代码变得容易。它最初是为 TorchRL 开发的,后来我们将其独立出来成为一个单独的库。

TensorDict 主要是一个字典,同时也是一个类似张量的类:它支持多种主要与形状和存储相关的张量操作。它的设计目标是能够高效地序列化或在节点之间、进程之间传输。最后,它附带了自己的 nn 模块,该模块与 torch.func 兼容,旨在使模型集成和参数操作更容易。

在本页中,我们将介绍 TensorDict 的动机,并举例说明它的功能。

动机

TensorDict 允许您编写通用的代码模块,这些模块可以在不同范式之间重用。例如,以下循环可以跨大多数 SL、SSL、UL 和 RL 任务重用。

>>> for i, tensordict in enumerate(dataset):
...     # the model reads and writes tensordicts
...     tensordict = model(tensordict)
...     loss = loss_module(tensordict)
...     loss.backward()
...     optimizer.step()
...     optimizer.zero_grad()

通过其 nn 模块,该包提供了许多工具,可以轻松地在代码库中使用 TensorDict

在多进程或分布式环境中,TensorDict 允许您无缝地将数据分派给每个工作进程。

>>> # creates batches of 10 datapoints
>>> splits = torch.arange(tensordict.shape[0]).split(10)
>>> for worker in range(workers):
...     idx = splits[worker]
...     pipe[worker].send(tensordict[idx])

TensorDict 提供的一些操作也可以通过 `tree_map` 完成,但复杂度更高。

>>> td = TensorDict(
...     {"a": torch.randn(3, 11), "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": td["a"], "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
...     {"a": regular_dicts["a"][i], "b": regular_dicts["b"][i]}
...     for i in range(3)]

嵌套情况甚至更具说服力。

>>> td = TensorDict(
...     {"a": {"c": torch.randn(3, 11)}, "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": {"c": td["a", "c"]}, "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
...     {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]}
...     for i in range(3)

在应用 `unbind` 操作后,将输出字典分解为三个结构相似的字典,在使用 `pytree` naively 处理时,会变得非常麻烦。使用 tensordict,我们为希望解绑或拆分嵌套结构的用户提供了一个简单的 API,而不是计算一个嵌套的拆分/解绑嵌套结构。

特性

A TensorDict 是一个类似字典的张量容器。要实例化一个 TensorDict,您可以指定键值对以及批次大小(可以通过 `TensorDict()` 创建一个空的 tensordict)。TensorDict 中任何值的领先维度必须与批次大小兼容。

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
...     {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
...     batch_size=[2, 3],
... )

设置或检索值的语法与常规字典非常相似。

>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)

还可以沿着批次大小索引一个 tensordict,这使得仅用几个字符就能获得一致的数据切片成为可能(请注意,使用 ellipsis 和 `tree_map` 索引第 n 个领先维度需要更多的编码)。

>>> sub_tensordict = tensordict[..., :2]

还可以使用 `inplace=True` 的 `set` 方法或 set_() 方法来就地更新内容。前者是后者的容错版本:如果没有找到匹配的键,它会写入一个新的键。

现在可以对 TensorDict 的内容进行集体操作。例如,要将所有内容放置到特定设备,只需执行以下操作:

>>> tensordict = tensordict.to("cuda:0")

然后您可以断言 tensordict 的设备是 `“cuda:0”`。

>>> assert tensordict.device == torch.device("cuda:0")

要重塑批次维度,可以这样做:

>>> tensordict = tensordict.reshape(6)

该类支持许多其他操作,包括 squeeze()unsqueeze()view()permute()unbind()stack()cat() 等等。

如果缺少某个操作,`apply()` 方法通常会提供所需的解决方案。

跳过形状操作

在某些情况下,可能希望在 TensorDict 中存储张量,而不在形状操作期间强制批次大小一致。

这可以通过将张量包装在 `UnbatchedTensor` 实例中来实现。

A `UnbatchedTensor` 在 TensorDict 上进行形状操作时会忽略其形状,从而允许灵活地存储和操作具有任意形状的张量。

>>> from tensordict import UnbatchedTensor
>>> tensordict = TensorDict({"zeros": UnbatchedTensor(torch.zeros(10))}, batch_size=[2, 3])
>>> reshaped_td = tensordict.reshape(6)
>>> reshaped_td["zeros"] is tensordict["zeros"]
True

非张量数据

Tensordict 是处理张量数据的强大库,但也支持非张量数据。本指南将向您展示如何使用 tensordict 处理非张量数据。

创建带有非张量数据的 TensorDict

您可以使用 `NonTensorData` 类创建带有非张量数据的 TensorDict。

>>> from tensordict import TensorDict, NonTensorData
>>> import torch
>>> td = TensorDict(
...     a=NonTensorData("a string!"),
...     b=torch.zeros(()),
... )
>>> print(td)
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

如您所见,`NonTensorData` 对象就像常规张量一样存储在 TensorDict 中。

`MetaData` 类可用于携带不可索引的数据,或不需要遵循 tensordict 批次大小的数据。

访问非张量数据

您可以使用键或 `get` 方法访问非张量数据。常规的 `getattr` 调用将返回 `NonTensorData` 对象的内容,而 `get()` 将返回 `NonTensorData` 对象本身。

>>> print(td["a"])  # prints: a string!
>>> print(td.get("a"))  # prints: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None)

批处理的非张量数据

如果您有批处理的非张量数据,则可以将其存储在具有指定批次大小的 TensorDict 中。

>>> td = TensorDict(
...     a=NonTensorData("a string!"),
...     b=torch.zeros(3),
...     batch_size=[3]
... )
>>> print(td)
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

在这种情况下,我们假设 tensordict 的所有元素都具有相同的非张量数据。

>>> print(td[0])
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

要为成形 tensordict 中的每个元素分配不同的非张量数据对象,可以使用非张量数据的堆栈。

堆叠的非张量数据

如果您有一个要存储在 `TensorDict` 中的非张量数据列表,您可以使用 `NonTensorStack` 类。

>>> td = TensorDict(
...     a=NonTensorStack("a string!", "another string!", "a third string!"),
...     b=torch.zeros(3),
...     batch_size=[3]
... )
>>> print(td)
TensorDict(
    fields={
        a: NonTensorStack(
            ['a string!', 'another string!', 'a third string!'...,
            batch_size=torch.Size([3]),
            device=None),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

您可以访问第一个元素,您将获得字符串的第一个元素。

>>> print(td[0])
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

相反,在 `NonTensorData` 中使用列表不会产生相同的结果,因为无法确定如何处理恰好是列表的非张量数据。

>>> td = TensorDict(
...     a=NonTensorData(["a string!", "another string!", "a third string!"]),
...     b=torch.zeros(3),
...     batch_size=[3]
... )
>>> print(td[0])
TensorDict(
    fields={
        a: NonTensorData(data=['a string!', 'another string!', 'a third string!'], batch_size=torch.Size([]), device=None),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

堆叠带有非张量数据的 TensorDicts

要堆叠非张量数据,`stack()` 将创建一个 `NonTensorStack`。相反,在使用 `MetaData` 实例时,如果其内容匹配,则堆叠操作将产生单个 `MetaData` 实例。

>>> td = TensorDict(
...     a=NonTensorData("a string!"),
... b = torch.zeros(()),
... )
>>> print(torch.stack([td, td]))
TensorDict(
    fields={
        a: NonTensorStack(
            ['a string!', 'a string!'],
            batch_size=torch.Size([2]),
            device=None),
        b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> td = TensorDict(
...     a=MetaData("a string!"),
... b = torch.zeros(()),
... )
>>> print(torch.stack([td, td]))
TensorDict(
    fields={
        a: MetaData(data=a string!, batch_size=torch.Size([2]), device=None),
        b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

命名维度

TensorDict 和相关类还支持维度名称。名称可以在构造时给出,也可以稍后进行细化。语义与 `torch.Tensor` 维度名称功能类似。

>>> tensordict = TensorDict({}, batch_size=[3, 4], names=["a", None])
>>> tensordict.refine_names(..., "b")
>>> tensordict.names = ["z", "y"]
>>> tensordict.rename("m", "n")
>>> tensordict.rename(m="h")

嵌套的 TensorDicts

A `TensorDict` 中的值本身可以是 TensorDicts(下面的示例中的嵌套字典将被转换为嵌套的 TensorDicts)。

>>> tensordict = TensorDict(
...     {
...         "inputs": {
...             "image": torch.rand(100, 28, 28),
...             "mask": torch.randint(2, (100, 28, 28), dtype=torch.uint8)
...         },
...         "outputs": {"logits": torch.randn(100, 10)},
...     },
...     batch_size=[100],
... )

可以使用字符串元组来访问或设置嵌套的键。

>>> image = tensordict["inputs", "image"]
>>> logits = tensordict.get(("outputs", "logits"))  # alternative way to access
>>> tensordict["outputs", "probabilities"] = torch.sigmoid(logits)

延迟求值

对 `TensorDict` 的某些操作会推迟执行,直到访问项目。例如,堆叠、挤压、升压、置换批次维度和创建视图不会立即在 `TensorDict` 的所有内容上执行。相反,它们是在访问 `TensorDict` 中的值时惰性执行的。这可以节省大量不必要的计算,特别是当 `TensorDict` 包含许多值时。

>>> tensordicts = [TensorDict({
...     "a": torch.rand(10),
...     "b": torch.rand(10, 1000, 1000)}, [10])
...     for _ in range(3)]
>>> stacked = torch.stack(tensordicts, 0)  # no stacking happens here
>>> stacked_a = stacked["a"]  # we stack the a values, b values are not stacked

它还有另一个优点,即我们可以操作堆栈中的原始 tensordicts。

>>> stacked["a"] = torch.zeros_like(stacked["a"])
>>> assert (tensordicts[0]["a"] == 0).all()

需要注意的是,`get` 方法现在变成了一个昂贵的操作,如果重复多次,可能会产生一些开销。可以通过在执行 `stack` 后调用 `tensordict.contiguous()` 来避免这种情况。为了进一步缓解此问题,TensorDict 附带了自己的元数据类(MetaTensor),该类可以跟踪字典中每个条目的类型、形状、dtype 和设备,而无需执行昂贵的操作。

延迟预分配

假设我们有一个函数 `foo()` -> `TensorDict`,并且我们执行如下操作:

>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in range(N):
...     tensordict[i] = foo()

当 `i == 0` 时,空的 `TensorDict` 将自动用批次大小为 N 的空张量填充。在循环的后续迭代中,更新将全部就地写入。

TensorDictModule

为了方便将 `TensorDict` 集成到代码库中,我们提供了 `tensordict.nn` 包,允许用户将 `TensorDict` 实例传递给 `Module` 对象(或任何可调用对象)。

`TensorDictModule` 包装了 `Module`,并接受单个 `TensorDict` 作为输入。您可以指定底层模块应从何处获取输入,以及何处写入其输出。这是我们能够编写可重用、通用的高级代码(如动机部分中的训练循环)的关键原因。

>>> from tensordict.nn import TensorDictModule
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.LazyLinear(1)
...
...     def forward(self, x):
...         logits = self.linear(x)
...         return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
...     Net(),
...     in_keys=["input"],
...     out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> tensordict = module(tensordict)
>>> # outputs can now be retrieved from the tensordict
>>> logits = tensordict["outputs", "logits"]
>>> probabilities = tensordict.get(("outputs", "probabilities"))

为了方便采用此类,还可以将张量作为关键字参数传递。

>>> tensordict = module(input=torch.randn(32, 100))

这将返回一个与上一个代码框中的 `TensorDict` 相同的 `TensorDict`。有关此功能的更多背景信息,请参阅 导出教程

许多 PyTorch 用户的痛点是 `nn.Sequential` 无法处理具有多个输入的模块。使用基于键的图可以轻松解决此问题,因为序列中的每个节点都知道需要读取哪些数据以及将数据写入何处。

为此,我们提供了 `TensorDictSequential` 类,该类将数据传递给一系列 `TensorDictModules`。序列中的每个模块都从原始 `TensorDict` 获取输入,并将其输出写入 `TensorDict`,这意味着序列中的模块可以忽略其前身的输出,或根据需要从 tensordict 获取其他输入。以下是一个示例:

>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
... class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> tensordict = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> tensordict = module(tensordict)
>>> intermediate_x = tensordict["intermediate", "x"]
>>> probabilities = tensordict["output", "probabilities"]

在此示例中,第二个模块将第一个模块的输出与存储在 `TensorDict` 中的 `(“inputs”, “mask”)` 结合起来。

`TensorDictSequential` 提供了许多其他功能:可以通过查询 `in_keys` 和 `out_keys` 属性来访问输入和输出键的列表。还可以通过用所需的输入和输出键集查询 `select_subsequence()` 来请求子图。这将返回另一个 `TensorDictSequential`,其中仅包含满足这些要求所必需的模块。`TensorDictModule` 也兼容 `vmap()` 和其他 `torch.func` 功能。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源