• 文档 >
  • 操作 TensorDict 的形状
快捷方式

操纵 TensorDict 的形状

作者: Tom Begley

在本教程中,您将学习如何操纵 TensorDict 及其内容的形状。

当我们创建一个 TensorDict 时,我们会指定一个 batch_size,它必须与 TensorDict 中所有条目的前导维度一致。由于我们保证所有条目共享这些共同维度,因此 TensorDict 能够公开多种方法,我们可以用这些方法来操纵 TensorDict 及其内容的形状。

import torch
from tensordict.tensordict import TensorDict

索引 TensorDict

由于所有条目都保证存在批处理维度,因此我们可以随意索引它们,并且 TensorDict 的每个条目都将以相同的方式被索引。

a = torch.rand(3, 4)
b = torch.rand(3, 4, 5)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])

indexed_tensordict = tensordict[:2, 1]
assert indexed_tensordict["a"].shape == torch.Size([2])
assert indexed_tensordict["b"].shape == torch.Size([2, 5])

重塑 TensorDict

TensorDict.reshape 的工作方式与 torch.Tensor.reshape() 相同。它沿着批处理维度应用于 TensorDict 的所有内容 - 请注意下面示例中 b 的形状。它还会更新 batch_size 属性。

reshaped_tensordict = tensordict.reshape(-1)
assert reshaped_tensordict.batch_size == torch.Size([12])
assert reshaped_tensordict["a"].shape == torch.Size([12])
assert reshaped_tensordict["b"].shape == torch.Size([12, 5])

分割 TensorDict

TensorDict.split 类似于 torch.Tensor.split()。它将 TensorDict 分割成多个块。每个块都是一个 TensorDict,其结构与原始 TensorDict 相同,但其条目是原始 TensorDict 中相应条目的视图。

chunks = tensordict.split([3, 1], dim=1)
assert chunks[0].batch_size == torch.Size([3, 3])
assert chunks[1].batch_size == torch.Size([3, 1])
torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])

注意

每当函数或方法接受 dim 参数时,负数维度将相对于函数或方法所调用的 TensorDictbatch_size 进行解释。特别是,如果存在嵌套的 TensorDict 值且批处理大小不同,则负数维度始终相对于根的批处理维度进行解释。

>>> tensordict = TensorDict(
...     {
...         "a": torch.rand(3, 4),
...         "nested": TensorDict({"b": torch.rand(3, 4, 5)}, [3, 4, 5])
...     },
...     [3, 4],
... )
>>> # dim = -2 will be interpreted as the first dimension throughout, as the root
>>> # TensorDict has 2 batch dimensions, even though the nested TensorDict has 3
>>> chunks = tensordict.split([2, 1], dim=-2)
>>> assert chunks[0].batch_size == torch.Size([2, 4])
>>> assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])

从这个例子可以看出,TensorDict.split 方法的行为与我们在调用之前将 dim=-2 替换为 dim=tensordict.batch_dims - 2 的行为完全相同。

Unbind

TensorDict.unbind 类似于 torch.Tensor.unbind(),并且在概念上与 TensorDict.split 相似。它会移除指定的维度,并返回沿该维度所有切片组成的 tuple

slices = tensordict.unbind(dim=1)
assert len(slices) == 4
assert all(s.batch_size == torch.Size([3]) for s in slices)
torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])

堆叠和连接 TensorDict

TensorDict 可以与 torch.cattorch.stack 结合使用。

堆叠 TensorDict

堆叠可以惰性地或连续地进行。惰性堆叠只是一个 tensordicts 列表,以 tensordicts 堆叠的形式呈现。它允许用户携带一个具有不同内容形状、设备或键集的 tensordicts 包。另一个优点是堆叠操作可能很耗时,并且如果只需要一小部分键,惰性堆叠将比真正的堆叠快得多。它依赖于 LazyStackedTensorDict 类。在这种情况下,值将在访问时才按需堆叠。

from tensordict import LazyStackedTensorDict

cloned_tensordict = tensordict.clone()
stacked_tensordict = LazyStackedTensorDict.lazy_stack(
    [tensordict, cloned_tensordict], dim=0
)
print(stacked_tensordict)

# Previously, torch.stack was always returning a lazy stack. For consistency with
# the regular PyTorch API, this behaviour will soon be adapted to deliver only
# dense tensordicts. To control which behaviour you are relying on, you can use
# the :func:`~tensordict.utils.set_lazy_legacy` decorator/context manager:

from tensordict.utils import set_lazy_legacy

with set_lazy_legacy(True):  # old behaviour
    lazy_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(lazy_stack, LazyStackedTensorDict)

with set_lazy_legacy(False):  # new behaviour
    dense_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(dense_stack, TensorDict)
LazyStackedTensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False,
    stack_dim=0)

如果我们沿堆叠维度索引一个 LazyStackedTensorDict,我们将恢复原始的 TensorDict

assert stacked_tensordict[0] is tensordict
assert stacked_tensordict[1] is cloned_tensordict

访问 LazyStackedTensorDict 中的键会导致这些值被堆叠。如果键对应于嵌套的 TensorDict,那么我们将恢复另一个 LazyStackedTensorDict

assert stacked_tensordict["a"].shape == torch.Size([2, 3, 4])

注意

由于值是按需堆叠的,多次访问一个项将导致它被多次堆叠,这是低效的。如果您需要多次访问堆叠的 TensorDict 中的某个项,您可能需要考虑将 LazyStackedTensorDict 转换为连续的 TensorDict,这可以通过 LazyStackedTensorDict.to_tensordictLazyStackedTensorDict.contiguous 方法来完成。

>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)
>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)

调用这些方法之一后,我们将得到一个包含堆叠值的常规 TensorDict,并且在访问值时不会执行额外的计算。

连接 TensorDict

连接不是惰性完成的,而是将 TensorDict 实例列表传递给 torch.cat() 会简单地返回一个 TensorDict,其条目是列表中元素的连接条目。

concatenated_tensordict = torch.cat([tensordict, cloned_tensordict], dim=0)
assert isinstance(concatenated_tensordict, TensorDict)
assert concatenated_tensordict.batch_size == torch.Size([6, 4])
assert concatenated_tensordict["b"].shape == torch.Size([6, 4, 5])

展开 TensorDict

我们可以使用 TensorDict.expand 展开 TensorDict 的所有条目。

exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])

压缩和解压 TensorDict

我们可以使用 squeeze()unsqueeze() 方法来压缩或解压 TensorDict 的内容。

tensordict = TensorDict({"a": torch.rand(3, 1, 4)}, [3, 1, 4])
squeezed_tensordict = tensordict.squeeze()
assert squeezed_tensordict["a"].shape == torch.Size([3, 4])
print(squeezed_tensordict, end="\n\n")

unsqueezed_tensordict = tensordict.unsqueeze(-1)
assert unsqueezed_tensordict["a"].shape == torch.Size([3, 1, 4, 1])
print(unsqueezed_tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 4]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 1, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 1, 4, 1]),
    device=None,
    is_shared=False)

注意

到目前为止,诸如 unsqueeze()squeeze()view()permute()transpose() 等操作,都返回这些操作的惰性版本(即一个容器,其中存储了原始 tensordict,并且在每次访问键时都会应用操作)。这种行为将在未来被弃用,并且可以通过 set_lazy_legacy() 函数来控制。

>>> with set_lazy_legacy(True):
...     lazy_unsqueeze = tensordict.unsqueeze(0)
>>> with set_lazy_legacy(False):
...     dense_unsqueeze = tensordict.unsqueeze(0)

请记住,一如既往,这些方法仅适用于批处理维度。条目的任何非批处理维度都不会受到影响。

tensordict = TensorDict({"a": torch.rand(3, 1, 1, 4)}, [3, 1])
squeezed_tensordict = tensordict.squeeze()
# only one of the singleton dimensions is dropped as the other
# is not a batch dimension
assert squeezed_tensordict["a"].shape == torch.Size([3, 1, 4])

查看 TensorDict

TensorDict 也支持 view。这会创建一个 _ViewedTensorDict,在访问其内容时惰性地创建其内容的视图。

tensordict = TensorDict({"a": torch.arange(12)}, [12])
# no views are created at this step
viewed_tensordict = tensordict.view((2, 3, 2))

# the view of "a" is created on-demand when we access it
assert viewed_tensordict["a"].shape == torch.Size([2, 3, 2])

排列批处理维度

torch.permute() 类似,可以使用 TensorDict.permute 方法来排列批处理维度。非批处理维度保持不变。

此操作是惰性的,因此批处理维度仅在我们尝试访问条目时才被排列。一如既往,如果您可能需要多次访问特定项,请考虑将其转换为 TensorDict

tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
# swap the batch dimensions
permuted_tensordict = tensordict.permute([1, 0])

assert permuted_tensordict["a"].shape == torch.Size([4, 3])
assert permuted_tensordict["b"].shape == torch.Size([4, 3, 5])

将 tensordicts 用作装饰器

对于许多可逆操作,tensordicts 可用作装饰器。这些操作包括用于函数调用的 to_module()unlock_()lock_(),或者像 view()permute() transpose()squeeze()unsqueeze() 这样的形状操作。以下是 transpose 函数的一个快速示例。

tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])

with tensordict.transpose(1, 0) as tdt:
    tdt.set("c", torch.ones(4, 3))  # we have permuted the dims

# the ``"c"`` entry is now in the tensordict we used as decorator:
#

assert (tensordict.get("c") == 1).all()

在 TensorDict 中收集值

torch.gather() 类似,可以使用 TensorDict.gather 方法沿批处理维度进行索引并将结果收集到单个维度中。

index = torch.randint(4, (3, 4))
gathered_tensordict = tensordict.gather(dim=1, index=index)
print("index:\n", index, end="\n\n")
print("tensordict['a']:\n", tensordict["a"], end="\n\n")
print("gathered_tensordict['a']:\n", gathered_tensordict["a"], end="\n\n")
index:
 tensor([[3, 1, 3, 1],
        [1, 1, 3, 2],
        [1, 2, 2, 0]])

tensordict['a']:
 tensor([[0.9659, 0.8909, 0.0103, 0.9228],
        [0.6158, 0.2101, 0.8202, 0.9474],
        [0.1096, 0.8732, 0.3140, 0.0605]])

gathered_tensordict['a']:
 tensor([[0.9228, 0.8909, 0.9228, 0.8909],
        [0.2101, 0.2101, 0.9474, 0.8202],
        [0.8732, 0.3140, 0.3140, 0.1096]])

脚本总运行时间: (0 分 0.008 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源