快捷方式

TensorClass

class tensordict.TensorClass(*args, **kwargs)

TensorClass 是 `@tensorclass` 装饰器的基于继承的版本。

TensorClass 允许你编写比使用 `@tensorclass` 装饰器构建的 dataclasses 具有更好的类型检查和更具 Pythonic 的代码。

示例

>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
关键字参数:
  • batch_size (torch.Size, optional) – TensorDict 的批次大小。默认为 None

  • device (torch.device, optional) – 将创建 TensorDict 的设备。默认为 None

  • frozen (bool, optional) – 如果为 True,则生成的类或实例将是不可变的。默认为 False

  • autocast (bool, optional) – 如果为 True,则为生成的类或实例启用自动类型转换。默认为 False

  • nocast (bool, optional) – 如果为 True,则禁用为生成的类或实例进行的任何类型转换。默认为 False

  • tensor_only (bool, optional) – 如果为 True,则预期 tensorclass 中的所有项都将是张量实例(张量兼容,因为非张量数据会被尽可能转换为张量)。这可以带来显著的速度提升,但会牺牲与非张量数据的灵活交互。默认为 False

  • shadow (bool, optional) – 禁用字段名与 TensorDict 保留属性的验证。请谨慎使用,这可能会导致意外后果。默认为 False。

你可以通过两种方式传递布尔关键字参数(“autocast”“nocast”“frozen”“tensor_only”“shadow”):使用

方括号或关键字参数。

示例

>>> class Foo(TensorClass["autocast"]):
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass, autocast=True):  # equivalent
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass["nocast"]):
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass["nocast", "frozen"]):  # multiple keywords can be used
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass, nocast=True):  # equivalent
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass):
...     integer: int
>>> Foo(integer=1).integer
tensor(1)

警告

TensorClass 本身没有被装饰为 tensorclass,但其子类将会。这是因为我们无法预知 `frozen` 参数是否会被设置,如果设置了,它可能与父类冲突(子类不能是 frozen 的,如果父类不是)。

dumps(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Any

将tensordict保存到磁盘。

此函数是 `memmap()` 的代理。

from_tensordict(tensordict: TensorDictBase, non_tensordict: Optional[dict] = None, safe: bool = True) Any

用于实例化新张量类对象的张量类包装器。

参数:
  • tensordict (TensorDictBase) – 张量类型的字典

  • non_tensordict (dict) – 包含非张量和嵌套张量类对象的字典

  • safe (bool) – 如果 tensordict 不是 TensorDictBase 实例,则是否引发错误

get(key: NestedKey, *args, **kwargs)

获取输入键对应的存储值。

参数:
  • key (str, str 的元组) – 要查询的键。如果是 str 的元组,则等同于链式调用 getattr。

  • default – 如果在张量类中找不到键,则返回默认值。

返回:

存储在输入键下的值

classmethod load(prefix: str | pathlib.Path, *args, **kwargs) Any

从磁盘加载 tensordict。

此类方法是 `load_memmap()` 的代理。

load_(prefix: str | pathlib.Path, *args, **kwargs)

在当前 tensordict 中从磁盘加载 tensordict。

此类方法是 load_memmap_() 的代理。

classmethod load_memmap(prefix: str | pathlib.Path, device: Optional[device] = None, non_blocking: bool = False, *, out: Optional[TensorDictBase] = None) Any

从磁盘加载内存映射的 tensordict。

参数:
  • prefix (str文件夹路径) – 应从中获取已保存 tensordict 的文件夹路径。

  • device (torch.device等效项, 可选) – 如果提供,数据将异步转换为该设备。支持 `"meta"` 设备,在这种情况下,数据不会被加载,而是创建一组空的 "meta" 张量。这对于在不实际打开任何文件的情况下了解模型大小和结构很有用。

  • non_blocking (bool, 可选) – 如果为 `True`,则在将张量加载到设备后不会调用同步。默认为 `False`。

  • out (TensorDictBase, 可选) – 应将数据写入其中的可选 tensordict。

示例

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

此方法还允许加载嵌套的 tensordicts。

示例

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

tensordict 也可以在“meta”设备上加载,或者作为假张量加载。

示例

>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)

尝试将 state_dict 加载到目标张量类中(原地)。

memmap(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) Any

将所有张量写入内存映射的 Tensor 中,并放入新的 tensordict。

参数:
  • prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果为 False(默认值),并且 tensordict 中某项已是存储在磁盘上的张量且关联了文件,但未按 prefix 保存到正确位置,则会引发异常。如果为 True,则任何现有张量都将被复制到新位置。

关键字参数:
  • num_threads (int, 可选) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, 可选) – 如果设置为 Truenum_threads>0,则该方法将返回 tensordict 的一个 future。

  • share_non_tensor (bool, 可选) – 如果设置为 True,则非张量数据将在进程之间共享,并且在单个节点内的任何工作者上进行的写入操作(例如就地更新或设置)将更新所有其他工作者上的值。如果非张量叶子节点数量很多(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,则如果同一路径下已存在张量,将引发异常。默认为 True

然后,Tensordict 被锁定,这意味着任何非就地写入操作(例如重命名、设置或删除条目)都将引发异常。一旦 tensordict 被解锁,内存映射属性将变为 False,因为不能保证跨进程身份。

返回:

返回一个新的 tensordict,其中张量存储在磁盘上(如果 return_early=False),否则返回一个 TensorDictFuture 实例。

注意

以这种方式序列化对于深度嵌套的 tensordicts 来说可能很慢,因此不建议在训练循环中调用此方法。

memmap_(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) Any

将所有张量原地写入相应的内存映射张量。

参数:
  • prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果为 False(默认值),并且 tensordict 中某项已是存储在磁盘上的张量且关联了文件,但未按 prefix 保存到正确位置,则会引发异常。如果为 True,则任何现有张量都将被复制到新位置。

关键字参数:
  • num_threads (int, 可选) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, optional) – 如果为 Truenum_threads>0,则方法将返回一个 tensordict 的 future。生成的 tensordict 可以使用 future.result() 进行查询。

  • share_non_tensor (bool, 可选) – 如果设置为 True,则非张量数据将在进程之间共享,并且在单个节点内的任何工作者上进行的写入操作(例如就地更新或设置)将更新所有其他工作者上的值。如果非张量叶子节点数量很多(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,则如果同一路径下已存在张量,将引发异常。默认为 True

然后,Tensordict 被锁定,这意味着任何非就地写入操作(例如重命名、设置或删除条目)都将引发异常。一旦 tensordict 被解锁,内存映射属性将变为 False,因为不能保证跨进程身份。

返回:

如果 return_early=False,则返回 self,否则返回 TensorDictFuture 实例。

注意

以这种方式序列化对于深度嵌套的 tensordicts 来说可能很慢,因此不建议在训练循环中调用此方法。

memmap_like(prefix: Optional[str] = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Any

创建一个无内容的内存映射 tensordict,其形状与原始 tensordict 相同。

参数:
  • prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果为 False(默认值),并且 tensordict 中某项已是存储在磁盘上的张量且关联了文件,但未按 prefix 保存到正确位置,则会引发异常。如果为 True,则任何现有张量都将被复制到新位置。

关键字参数:
  • num_threads (int, 可选) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, 可选) – 如果设置为 Truenum_threads>0,则该方法将返回 tensordict 的一个 future。

  • share_non_tensor (bool, 可选) – 如果设置为 True,则非张量数据将在进程之间共享,并且在单个节点内的任何工作者上进行的写入操作(例如就地更新或设置)将更新所有其他工作者上的值。如果非张量叶子节点数量很多(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,则如果同一路径下已存在张量,将引发异常。默认为 True

然后,Tensordict 被锁定,这意味着任何非就地写入操作(例如重命名、设置或删除条目)都将引发异常。一旦 tensordict 被解锁,内存映射属性将变为 False,因为不能保证跨进程身份。

返回:

如果 return_early=False,则创建一个新的 TensorDict 实例,其中数据存储为内存映射张量;否则,创建一个 TensorDictFuture 实例。

注意

这是将一组大型缓冲区写入磁盘的推荐方法,因为 `memmap_()` 将会复制信息,这对于大型内容来说可能会很慢。

示例

>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
memmap_refresh_()

如果内存映射的 tensordict 具有 saved_path,则刷新其内容。

如果没有任何路径与之关联,此方法将引发异常。

save(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Any

将tensordict保存到磁盘。

此函数是 `memmap()` 的代理。

set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)

设置一个新的键值对。

参数:
  • key (str, tuple of str) – 要设置的键名。如果是字符串元组,则等同于链式调用 getattr,然后最后调用 setattr。

  • value (Any) – 要存储在张量类中的值

  • inplace (bool, optional) – 如果为 True,则 set 将尝试就地更新值。如果为 False 或键不存在,则值将简单地写入其目标位置。

返回:

self

state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]

返回一个 state_dict 字典,可用于保存和加载张量类的数据。

to_tensordict(*, retain_none: Optional[bool] = None) TensorDict

将张量类转换为常规 TensorDict。

复制所有条目。内存映射和共享内存张量将被转换为常规张量。

参数:

retain_none (bool) – 如果 True,则 None 值将被写入 tensordict。否则,它们将被丢弃。默认值:True

返回:

包含与张量类相同值的新的 TensorDict 对象。

unbind(dim: int)

返回沿指定维度解绑的索引张量类实例的元组。

结果张量类实例将共享初始张量类实例的存储。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源