快捷方式

tensorclass

@tensorclass 装饰器帮助您构建继承自 TensorDict 行为的自定义类,同时能够将可能的条目限制为预定义集合或为您的类实现自定义方法。

TensorDict 一样,@tensorclass 支持嵌套、索引、重塑、项赋值。它还支持张量操作,如 clone, squeeze, torch.cat, split 等。@tensorclass 允许非张量条目,但是所有张量操作都严格限制在张量属性上。

需要为非张量数据实现自定义方法。需要注意的是,@tensorclass 不强制严格的类型匹配

>>> from __future__ import annotations
>>> from tensordict.prototype import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     floatdata: torch.Tensor
...     intdata: torch.Tensor
...     non_tensordata: str
...     nested: Optional[MyData] = None
...
...     def check_nested(self):
...         assert self.nested is not None
>>>
>>> data = MyData(
...   floatdata=torch.randn(3, 4, 5),
...   intdata=torch.randint(10, (3, 4, 1)),
...   non_tensordata="test",
...   batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=None,
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
>>> data.nested = MyData(
...     floatdata = torch.randn(3, 4, 5),
...     intdata=torch.randint(10, (3, 4, 1)),
...     non_tensordata="nested_test",
...     batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=MyData(
      floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([3, 4]),
      device=None,
      is_shared=False),
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)

正如 TensorDict 的情况一样,从 v0.4 开始,如果省略批次大小,则认为其为空。

如果提供了非空批次大小,@tensorclass 支持索引。内部会索引张量对象,但是非张量数据保持不变

>>> print("indexed:", data[:2])
indexed: MyData(
   floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test',
   nested=MyData(
      floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([2, 4]),
      device=None,
      is_shared=False),
   batch_size=torch.Size([2, 4]),
   device=None,
   is_shared=False)

@tensorclass 还支持设置和重置属性,即使是嵌套对象。

>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata: ", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'

>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]]])

>>> # Changing nested tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'

@tensorclass 支持对其内容进行形状和设备的多个 torch 操作,例如 stack, cat, reshapeto(device)。要获取支持操作的完整列表,请参阅 tensordict 文档。

这是一个例子:

>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", catted_tc)
Concatenated data: MyData(
   floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test_changed',
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
       intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
       non_tensordata='nested_test_changed',
       nested=None,
       batch_size=torch.Size([6, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([6, 4]),
   device=None,
   is_shared=False)

序列化

保存 tensorclass 实例可以通过 memmap 方法实现。保存策略如下:张量数据将使用内存映射张量保存,而可以使用 json 格式序列化的非张量数据将以此方式保存。其他数据类型将使用 save() 保存,该方法依赖于 pickle

反序列化 tensorclass 可以通过 load_memmap() 完成。创建的实例将具有与保存的实例相同的类型,前提是 tensorclass 在工作环境中可用

>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))

边缘情况

@tensorclass 支持相等和不等运算符,即使是嵌套对象。请注意,非张量/元数据未经过验证。这将返回一个具有布尔值(用于张量属性)和 None(用于非张量属性)的张量类对象

这是一个例子:

>>> print(data == data2)
MyData(
   floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
   intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
   non_tensordata=None,
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
       intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
       non_tensordata=None,
       nested=None,
       batch_size=torch.Size([3, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

@tensorclass 支持设置一个项。但是,在设置项时,会进行非张量/元数据的标识检查而不是相等性检查,以避免性能问题。用户需要确保项的非张量数据与对象匹配,以避免差异。

这是一个例子:

在设置具有不同 non_tensor 数据的项时,会抛出 UserWarning

>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

尽管 @tensorclass 支持 cat()stack() 等 torch 函数,但非张量/元数据不会被验证。torch 操作会在张量数据上执行,并在返回输出时,会考虑第一个 tensor class 对象的非张量/元数据。用户需要确保所有 tensor class 对象列表具有相同的非张量数据,以避免差异

这是一个例子:

>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.cat([data, data2], dim=0)
>>> print(stack_tc)
MyData(
    floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
    non_tensordata='test',
    nested=MyData(
        floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        non_tensordata='nested_test',
        nested=None,
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)

@tensorclass 还支持预分配,您可以将属性初始化为 None,然后稍后设置它们。请注意,在初始化时,内部的 None 属性将保存为非张量/元数据,而在重置时,根据属性值的类型,它将被保存为张量数据或非张量/元数据

这是一个例子:

>>> @tensorclass
... class MyClass:
...   X: Any
...   y: Any

>>> data = MyClass(X=None, y=None, batch_size = [3,4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
   X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   y='testing',
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

tensorclass([cls, autocast, frozen, nocast, ...])

一个用于创建 tensorclass 类的装饰器。

TensorClass(*args, **kwargs)

TensorClass 是 @tensorclass 装饰器的基于继承的版本。

NonTensorData(data[, _metadata, ...])

MetaData(data[, _metadata, _is_non_tensor, ...])

NonTensorStack(*args, **kwargs)

LazyStackedTensorDict 的一个薄包装器,用于轻松识别非张量数据的堆叠。

from_dataclass(obj, *[, dest_cls, ...])

将 dataclass 实例或类型分别转换为 tensorclass 实例或类型。

自动类型转换

警告

自动类型转换是一项实验性功能,未来可能会发生变化。与 python<=3.9 的兼容性有限。

@tensorclass 作为一项实验性功能部分支持自动类型转换。__setattr__, update, update_from_dict 等方法将尝试将类型注解的条目转换为所需的 TensorDict / tensorclass 实例(除非发生如下所述的情况)。例如,以下代码将把 td 字典转换为 TensorDict,并将 tc 条目转换为 MyClass 实例

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)

注意

包含 typing.Optionaltyping.Union 的类型注解条目将与自动类型转换不兼容,但 tensorclass 中的其他条目将兼容

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     tc_autocast: MyClass = None
...     tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     tc_autocast={"tensor": torch.randn(())},
...     tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

如果类中的至少一个条目使用 type0 | type1 语义进行注解,则整个类的自动类型转换功能将被禁用。因为 tensorclass 支持非张量叶子,在这种情况下设置字典将导致将其设置为普通字典而不是 tensor collection 子类(TensorDicttensorclass

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass | None
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

注意

自动类型转换未对叶子(张量)启用。原因是此功能与包含 type0 | type1 类型提示语义的类型注解不兼容,后者很普遍。允许自动类型转换将导致非常相似的代码,如果类型注解仅有细微差别,行为就会有很大差异。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源