快捷方式

tensordict 包

TensorDict 类通过打包成一个继承自常规 PyTorch 张量特性的类字典对象,简化了在模块之间传递多个张量的过程。

TensorDictBase()

TensorDictBase 是 TensorDict 的抽象父类,TensorDict 是一个 torch.Tensor 数据容器。

TensorDict([source, batch_size, device, ...])

张量字典。

LazyStackedTensorDict(*tensordicts[, ...])

TensorDict 的懒惰堆叠。

PersistentTensorDict(*[, batch_size, ...])

持久化 TensorDict 实现。

TensorDictParams([parameters, no_convert, lock])

用于暴露参数的 TensorDictBase 包装器。

get_defaults_to_none([set_to_none])

返回 `get` 默认值的状态。

构造函数和处理器

该库提供了几种与 numpy 结构化数组、命名元组或 h5 文件等其他数据结构进行交互的方法。该库还公开了专门用于操作 tensordict 的函数,例如 `save`、`load`、`stack` 或 `cat`。

cat(input[, dim, out])

将tensordicts沿给定维度连接成一个tensordict。

default_is_leaf(cls)

如果一个类型不是张量集合(tensordict 或 tensorclass),则返回 `True`。

from_any(obj, *[, auto_batch_size, ...])

将任何对象转换为 TensorDict。

from_consolidated(filename)

从合并文件中重构 tensordict。

from_dict(d, *[, auto_batch_size, ...])

将字典转换为 TensorDict。

from_h5(h5_file, *[, auto_batch_size, ...])

将 HDF5 文件转换为 TensorDict。

from_module(module[, as_module, lock, ...])

将模块的参数和缓冲区复制到 tensordict 中。

from_modules(*modules[, as_module, lock, ...])

为 vmap 的 ensemable 学习/特征期望应用检索多个模块的参数。

from_namedtuple(named_tuple, *[, ...])

将命名元组转换为 TensorDict。

from_pytree(pytree, *[, batch_size, ...])

将 pytree 转换为 TensorDict 实例。

from_struct_array(struct_array, *[, ...])

将结构化 numpy 数组转换为 TensorDict。

from_tuple(obj, *[, auto_batch_size, ...])

将元组转换为 TensorDict。

fromkeys(keys[, value])

从键列表和单个值创建 tensordict。

is_batchedtensor(arg0)

is_leaf_nontensor(cls)

如果一个类型不是张量集合(tensordict 或 tensorclass)或是非张量,则返回 `True`。

lazy_stack(input[, dim, out])

创建 TensorDicts 的懒惰堆叠。

load(prefix[, device, non_blocking, out])

从磁盘加载 tensordict。

load_memmap(prefix[, device, non_blocking, out])

从磁盘加载内存映射的 tensordict。

maybe_dense_stack(input[, dim, out])

尝试使 TensorDicts 密集堆叠,并在需要时回退到懒惰堆叠。

memmap(data[, prefix, copy_existing, ...])

将所有张量写入内存映射的 Tensor 中,并放入新的 tensordict。

save(data[, prefix, copy_existing, ...])

将tensordict保存到磁盘。

stack(input[, dim, out])

沿给定维度将 tensordicts 堆叠成一个单一的 tensordict。

TensorDict 作为上下文管理器

TensorDict 可用作上下文管理器,用于需要执行然后撤销的操作。这包括临时锁定/解锁 tensordict

>>> data.lock_()  # data.set will result in an exception
>>> with data.unlock_():
...     data.set("key", value)
>>> assert data.is_locked()

或使用包含模型参数和缓冲区的 TensorDict 实例执行函数调用。

>>> params = TensorDict.from_module(module).clone()
>>> params.zero_()
>>> with params.to_module(module):
...     y = module(x)

在第一个示例中,我们可以修改 tensordict `data`,因为我们已经临时解锁了它。在第二个示例中,我们使用 `params` tensordict 实例中包含的参数和缓冲区填充模块,并在该调用完成后重置原始参数。

内存映射张量

`tensordict` 提供了 `MemoryMappedTensor` 原语,允许您方便地处理存储在物理内存中的张量。`MemoryMappedTensor` 的主要优点是其易于构造(无需处理张量的存储)、能够处理不适合内存的大型连续数据、跨进程的高效(反)序列化以及对存储张量的高效索引。

如果所有工作进程都可以访问相同的存储(在多进程和分布式设置中),则传递 `MemoryMappedTensor` 将仅包括传递对磁盘文件的引用以及大量用于重建它的元数据。对于索引内存映射张量也是如此,只要它们的存储的数据指针与原始数据指针相同。

索引内存映射张量比从磁盘加载多个独立文件要快得多,并且不需要将数组的全部内容加载到内存中。然而,PyTorch 张量的物理存储不应有任何不同。

>>> my_images = MemoryMappedTensor.empty((1_000_000, 3, 480, 480), dtype=torch.unint8)
>>> mini_batch = my_images[:10]  # just reads the first 10 images of the dataset

MemoryMappedTensor(source, *[, dtype, ...])

内存映射张量。

逐点操作

Tensordict 支持各种逐点操作,允许您对其中存储的张量执行逐元素计算。这些操作类似于对常规 PyTorch 张量执行的操作。

支持的操作

目前支持以下逐点操作:

  • 左乘和右乘(+)

  • 左减和右减(-)

  • 左乘和右乘(*)

  • 左除和右除(/)

  • 左幂(**)

还支持许多其他操作,例如 `clamp()`、`sqrt()` 等。

执行逐点操作

您可以在两个 Tensordict 之间,或在 Tensordict 和张量/标量值之间执行逐点操作。

示例 1:Tensordict-Tensordict 操作

>>> import torch
>>> from tensordict import TensorDict
>>> td1 = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> td2 = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> result = td1 * td2

在此示例中,`*` 运算符逐元素应用于 td1 和 td2 中的相应张量。

示例 2:Tensordict-Tensor 操作

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> tensor = torch.randn(4)
>>> result = td * tensor

在此,`*` 运算符逐元素应用于 td 中的每个张量和提供的张量。该张量被广播以匹配 Tensordict 中每个张量的形状。

示例 3:Tensordict-Scalar 操作

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> scalar = 2.0
>>> result = td * scalar

在这种情况下,`*` 运算符逐元素应用于 td 中的每个张量和提供的标量。

广播规则

当在 Tensordict 和张量/标量之间执行逐点操作时,张量/标量会被广播以匹配 Tensordict 中每个张量的形状:张量首先在左侧广播以匹配 tensordict 的形状,然后单独在右侧广播以匹配张量的形状。如果您将 `TensorDict` 视为单个张量实例,则这遵循 PyTorch 中使用的标准广播规则。

例如,如果您有一个形状为 `(3, 4)` 的 TensorDict,并将其乘以形状为 `(4,)` 的张量,则该张量将在应用操作之前被广播到形状 `(3, 4)`。如果 tensordict 包含一个形状为 `(3, 4, 5)` 的张量,则用于乘法的张量将在右侧广播到形状 `(3, 4, 5)` 以进行该乘法。

如果逐点操作跨多个 tensordict 执行并且它们的批次大小不同,它们将被广播到公共形状。

逐点操作的效率

在可能的情况下,将使用 `torch._foreach_` 融合内核来加速逐点操作的计算。

处理缺失条目

在两个 Tensordict 之间执行逐点操作时,它们必须具有相同的键。一些操作,例如 `add()`,有一个 `default` 关键字参数,可用于与具有独占条目的 tensordict 进行操作。如果 `default=None`(默认值),则两个 Tensordict 必须具有完全匹配的键集。如果 `default="intersection"`,则仅考虑交集的键集,其他键将被忽略。在所有其他情况下,`default` 将用于操作两侧的所有缺失条目。

实用工具

utils.expand_as_right(tensor, dest)

将张量向右扩展以匹配另一个张量形状。

utils.expand_right(tensor, shape)

将张量向右扩展以匹配所需形状。

utils.isin(input, reference, key[, dim])

测试 `input` `dim` 中的 `key` 的每个元素是否也存在于 `reference` 中。

utils.remove_duplicates(input, key[, dim, ...])

在指定维度上移除 `key` 中重复的索引。

capture_non_tensor_stack([allow_none])

获取捕获非张量堆栈的当前设置。

dense_stack_tds(td_list[, dim])

密集堆叠一个列表的 `TensorDictBase` 对象(或 `LazyStackedTensorDict`),前提是它们具有相同的结构。

is_batchedtensor(arg0)

is_tensor_collection(datatype)

检查数据对象或类型是否是 tensordict 库中的张量容器。

lazy_legacy([allow_none])

如果为选定的方法使用懒惰转换,则返回 `True`。

make_tensordict([input_dict, batch_size, ...])

从关键字参数或输入字典返回一个创建的 TensorDict。

merge_tensordicts(*tensordicts[, callback_exist])

将 tensordicts 合并在一起。

pad(tensordict, pad_size[, value])

使用常量值沿批次维度填充 tensordict 中的所有张量,并返回一个新的 tensordict。

pad_sequence(list_of_tensordicts[, pad_dim, ...])

填充 tensordict 列表,以便它们可以以连续格式堆叠在一起。

parse_tensor_dict_string(s)

将 TensorDict repr 解析为 TensorDict。

set_capture_non_tensor_stack(mode)

一个上下文管理器或装饰器,用于控制是否应将相同的非张量数据堆叠到单个 NonTensorData 对象或 NonTensorStack 中。

set_lazy_legacy(mode)

将某些方法的行为设置为懒惰转换。

set_list_to_stack(mode)

用于控制 TensorDict 中列表处理行为的上下文管理器和装饰器。

list_to_stack([allow_none])

检索 TensorDict 中列表到堆栈转换的当前设置。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源