tensordict 包¶
该 TensorDict
类通过打包成一个继承自常规 PyTorch 张量特性的类字典对象,简化了在模块之间传递多个张量的过程。
TensorDictBase 是 TensorDict 的抽象父类,TensorDict 是一个 torch.Tensor 数据容器。 |
|
|
张量字典。 |
|
TensorDict 的懒惰堆叠。 |
|
持久化 TensorDict 实现。 |
|
用于暴露参数的 TensorDictBase 包装器。 |
|
返回 `get` 默认值的状态。 |
构造函数和处理器¶
该库提供了几种与 numpy 结构化数组、命名元组或 h5 文件等其他数据结构进行交互的方法。该库还公开了专门用于操作 tensordict 的函数,例如 `save`、`load`、`stack` 或 `cat`。
|
将tensordicts沿给定维度连接成一个tensordict。 |
|
如果一个类型不是张量集合(tensordict 或 tensorclass),则返回 `True`。 |
|
将任何对象转换为 TensorDict。 |
|
从合并文件中重构 tensordict。 |
|
将字典转换为 TensorDict。 |
|
将 HDF5 文件转换为 TensorDict。 |
|
将模块的参数和缓冲区复制到 tensordict 中。 |
|
为 vmap 的 ensemable 学习/特征期望应用检索多个模块的参数。 |
|
将命名元组转换为 TensorDict。 |
|
将 pytree 转换为 TensorDict 实例。 |
|
将结构化 numpy 数组转换为 TensorDict。 |
|
将元组转换为 TensorDict。 |
|
从键列表和单个值创建 tensordict。 |
|
|
|
如果一个类型不是张量集合(tensordict 或 tensorclass)或是非张量,则返回 `True`。 |
|
创建 TensorDicts 的懒惰堆叠。 |
|
从磁盘加载 tensordict。 |
|
从磁盘加载内存映射的 tensordict。 |
|
尝试使 TensorDicts 密集堆叠,并在需要时回退到懒惰堆叠。 |
|
将所有张量写入内存映射的 Tensor 中,并放入新的 tensordict。 |
|
将tensordict保存到磁盘。 |
|
沿给定维度将 tensordicts 堆叠成一个单一的 tensordict。 |
TensorDict 作为上下文管理器¶
TensorDict
可用作上下文管理器,用于需要执行然后撤销的操作。这包括临时锁定/解锁 tensordict
>>> data.lock_() # data.set will result in an exception
>>> with data.unlock_():
... data.set("key", value)
>>> assert data.is_locked()
或使用包含模型参数和缓冲区的 TensorDict 实例执行函数调用。
>>> params = TensorDict.from_module(module).clone()
>>> params.zero_()
>>> with params.to_module(module):
... y = module(x)
在第一个示例中,我们可以修改 tensordict `data`,因为我们已经临时解锁了它。在第二个示例中,我们使用 `params` tensordict 实例中包含的参数和缓冲区填充模块,并在该调用完成后重置原始参数。
内存映射张量¶
`tensordict` 提供了 `MemoryMappedTensor` 原语,允许您方便地处理存储在物理内存中的张量。`MemoryMappedTensor` 的主要优点是其易于构造(无需处理张量的存储)、能够处理不适合内存的大型连续数据、跨进程的高效(反)序列化以及对存储张量的高效索引。
如果所有工作进程都可以访问相同的存储(在多进程和分布式设置中),则传递 `MemoryMappedTensor` 将仅包括传递对磁盘文件的引用以及大量用于重建它的元数据。对于索引内存映射张量也是如此,只要它们的存储的数据指针与原始数据指针相同。
索引内存映射张量比从磁盘加载多个独立文件要快得多,并且不需要将数组的全部内容加载到内存中。然而,PyTorch 张量的物理存储不应有任何不同。
>>> my_images = MemoryMappedTensor.empty((1_000_000, 3, 480, 480), dtype=torch.unint8)
>>> mini_batch = my_images[:10] # just reads the first 10 images of the dataset
|
内存映射张量。 |
逐点操作¶
Tensordict 支持各种逐点操作,允许您对其中存储的张量执行逐元素计算。这些操作类似于对常规 PyTorch 张量执行的操作。
支持的操作¶
目前支持以下逐点操作:
左乘和右乘(+)
左减和右减(-)
左乘和右乘(*)
左除和右除(/)
左幂(**)
还支持许多其他操作,例如 `clamp()`、`sqrt()` 等。
执行逐点操作¶
您可以在两个 Tensordict 之间,或在 Tensordict 和张量/标量值之间执行逐点操作。
示例 1:Tensordict-Tensordict 操作¶
>>> import torch
>>> from tensordict import TensorDict
>>> td1 = TensorDict(
... a=torch.randn(3, 4),
... b=torch.zeros(3, 4, 5),
... c=torch.ones(3, 4, 5, 6),
... batch_size=(3, 4),
... )
>>> td2 = TensorDict(
... a=torch.randn(3, 4),
... b=torch.zeros(3, 4, 5),
... c=torch.ones(3, 4, 5, 6),
... batch_size=(3, 4),
... )
>>> result = td1 * td2
在此示例中,`*` 运算符逐元素应用于 td1 和 td2 中的相应张量。
示例 2:Tensordict-Tensor 操作¶
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
... a=torch.randn(3, 4),
... b=torch.zeros(3, 4, 5),
... c=torch.ones(3, 4, 5, 6),
... batch_size=(3, 4),
... )
>>> tensor = torch.randn(4)
>>> result = td * tensor
在此,`*` 运算符逐元素应用于 td 中的每个张量和提供的张量。该张量被广播以匹配 Tensordict 中每个张量的形状。
示例 3:Tensordict-Scalar 操作¶
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
... a=torch.randn(3, 4),
... b=torch.zeros(3, 4, 5),
... c=torch.ones(3, 4, 5, 6),
... batch_size=(3, 4),
... )
>>> scalar = 2.0
>>> result = td * scalar
在这种情况下,`*` 运算符逐元素应用于 td 中的每个张量和提供的标量。
广播规则¶
当在 Tensordict 和张量/标量之间执行逐点操作时,张量/标量会被广播以匹配 Tensordict 中每个张量的形状:张量首先在左侧广播以匹配 tensordict 的形状,然后单独在右侧广播以匹配张量的形状。如果您将 `TensorDict` 视为单个张量实例,则这遵循 PyTorch 中使用的标准广播规则。
例如,如果您有一个形状为 `(3, 4)` 的 TensorDict,并将其乘以形状为 `(4,)` 的张量,则该张量将在应用操作之前被广播到形状 `(3, 4)`。如果 tensordict 包含一个形状为 `(3, 4, 5)` 的张量,则用于乘法的张量将在右侧广播到形状 `(3, 4, 5)` 以进行该乘法。
如果逐点操作跨多个 tensordict 执行并且它们的批次大小不同,它们将被广播到公共形状。
逐点操作的效率¶
在可能的情况下,将使用 `torch._foreach_
处理缺失条目¶
在两个 Tensordict 之间执行逐点操作时,它们必须具有相同的键。一些操作,例如 `add()`,有一个 `default` 关键字参数,可用于与具有独占条目的 tensordict 进行操作。如果 `default=None`(默认值),则两个 Tensordict 必须具有完全匹配的键集。如果 `default="intersection"`,则仅考虑交集的键集,其他键将被忽略。在所有其他情况下,`default` 将用于操作两侧的所有缺失条目。
实用工具¶
|
将张量向右扩展以匹配另一个张量形状。 |
|
将张量向右扩展以匹配所需形状。 |
|
测试 `input` `dim` 中的 `key` 的每个元素是否也存在于 `reference` 中。 |
|
在指定维度上移除 `key` 中重复的索引。 |
|
获取捕获非张量堆栈的当前设置。 |
|
密集堆叠一个列表的 `TensorDictBase` 对象(或 `LazyStackedTensorDict`),前提是它们具有相同的结构。 |
|
|
|
检查数据对象或类型是否是 tensordict 库中的张量容器。 |
|
如果为选定的方法使用懒惰转换,则返回 `True`。 |
|
从关键字参数或输入字典返回一个创建的 TensorDict。 |
|
将 tensordicts 合并在一起。 |
|
使用常量值沿批次维度填充 tensordict 中的所有张量,并返回一个新的 tensordict。 |
|
填充 tensordict 列表,以便它们可以以连续格式堆叠在一起。 |
将 TensorDict repr 解析为 TensorDict。 |
|
一个上下文管理器或装饰器,用于控制是否应将相同的非张量数据堆叠到单个 NonTensorData 对象或 NonTensorStack 中。 |
|
|
将某些方法的行为设置为懒惰转换。 |
|
用于控制 TensorDict 中列表处理行为的上下文管理器和装饰器。 |
|
检索 TensorDict 中列表到堆栈转换的当前设置。 |