
tensorclass

The @tensorclass decorator helps you build custom classes that inherit their behaviour from TensorDict, while letting you restrict the possible entries to a predefined set and implement custom methods for your class.

Like TensorDict, @tensorclass supports nesting, indexing, reshaping and item assignment. It also supports tensor operations such as clone, squeeze, torch.cat, split and many more. @tensorclass allows non-tensor entries; however, all tensor operations are strictly restricted to tensor attributes.

Custom methods for non-tensor data must be implemented by the user. Note also that @tensorclass does not enforce strict type matching.

>>> from __future__ import annotations
>>> from tensordict import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     floatdata: torch.Tensor
...     intdata: torch.Tensor
...     non_tensordata: str
...     nested: Optional[MyData] = None
...
...     def check_nested(self):
...         assert self.nested is not None
>>>
>>> data = MyData(
...   floatdata=torch.randn(3, 4, 5),
...   intdata=torch.randint(10, (3, 4, 1)),
...   non_tensordata="test",
...   batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=None,
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
>>> data.nested = MyData(
...     floatdata=torch.randn(3, 4, 5),
...     intdata=torch.randint(10, (3, 4, 1)),
...     non_tensordata="nested_test",
...     batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=MyData(
      floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([3, 4]),
      device=None,
      is_shared=False),
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
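Because @tensorclass does not enforce strict type matching, its annotations behave much like those on a standard Python dataclass: they describe intent but are not checked when values are assigned. The stdlib-only sketch below shows this with a plain dataclass. `Record` and `mismatched_fields` are hypothetical names, and the manual check is just one way you might add validation yourself, not a tensordict feature.

```python
from dataclasses import dataclass, fields


@dataclass
class Record:
    # Annotations document intent; dataclasses do not check them at runtime.
    floatdata: list
    non_tensordata: str


# Values that contradict the annotations are accepted without complaint.
rec = Record(floatdata="not a list", non_tensordata=42)


def mismatched_fields(obj):
    """Hand-written check: names of fields whose value is not an instance of
    the annotated type (only plain class annotations are checked)."""
    return [
        f.name
        for f in fields(obj)
        if isinstance(f.type, type) and not isinstance(getattr(obj, f.name), f.type)
    ]


print(rec)                     # Record(floatdata='not a list', non_tensordata=42)
print(mismatched_fields(rec))  # ['floatdata', 'non_tensordata']
```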

As is the case with TensorDict, starting from v0.4 an omitted batch size is treated as empty.

If a non-empty batch size is provided, @tensorclass supports indexing. Internally, the tensor objects get indexed, while the non-tensor data remains the same:

>>> print("indexed:", data[:2])
indexed: MyData(
   floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test',
   nested=MyData(
      floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([2, 4]),
      device=None,
      is_shared=False),
   batch_size=torch.Size([2, 4]),
   device=None,
   is_shared=False)
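The indexing rule can be pictured with a toy model: the index is applied to each tensor entry along the batch dimension, and the non-tensor entry is copied as-is. In the stdlib-only sketch below, nested Python lists stand in for tensors and `ToyData` is a hypothetical class, so it only mirrors the documented behaviour, not tensordict's implementation.

```python
from dataclasses import dataclass


@dataclass
class ToyData:
    floatdata: list       # stands in for a tensor; first level = batch dimension
    intdata: list         # stands in for a tensor; first level = batch dimension
    non_tensordata: str   # metadata: never indexed

    def __getitem__(self, index):
        # Index every "tensor" field along the batch dimension...
        return ToyData(
            floatdata=self.floatdata[index],
            intdata=self.intdata[index],
            # ...and carry the non-tensor field over unchanged.
            non_tensordata=self.non_tensordata,
        )


data = ToyData(
    floatdata=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    intdata=[[1], [2], [3]],
    non_tensordata="test",
)
sub = data[:2]
print(len(sub.floatdata), sub.non_tensordata)  # 2 test
```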

@tensorclass also supports setting and resetting attributes, even for nested objects.

>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata:", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'

>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]]])

>>> # Changing nested non-tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'

@tensorclass supports multiple torch operations over the shape and device of its content, such as stack, cat, reshape or to(device). For a full list of supported operations, check the tensordict documentation.

Here is an example:

>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", cat_tc)
Concatenated data: MyData(
   floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test_changed',
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
       intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
       non_tensordata='nested_test_changed',
       nested=None,
       batch_size=torch.Size([6, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([6, 4]),
   device=None,
   is_shared=False)

Serialization

Saving a tensorclass instance can be achieved with the memmap method. The saving strategy is as follows: tensor data is saved using memory-mapped tensors, non-tensor data that can be serialized to JSON is saved as such, and other data types are saved using save(), which relies on pickle.

A tensorclass can be deserialized via load_memmap(). The resulting instance has the same type as the one that was saved, provided that the tensorclass is available in the working environment:

>>> from tensordict import TensorDict
>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))
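The JSON-first, pickle-fallback strategy for non-tensor data can be sketched with the standard library alone. The helpers `save_non_tensor` and `load_non_tensor`, as well as the file names they produce, are hypothetical and do not reproduce tensordict's on-disk layout; the sketch only illustrates the dispatch logic.

```python
import json
import pickle
import tempfile
from pathlib import Path


def save_non_tensor(value, directory, name):
    """Save `value` as JSON when possible, otherwise fall back to pickle.
    Hypothetical helper: the file naming is illustrative only."""
    try:
        payload = json.dumps(value)
    except (TypeError, ValueError):
        path = directory / f"{name}.pkl"
        path.write_bytes(pickle.dumps(value))
    else:
        path = directory / f"{name}.json"
        path.write_text(payload)
    return path


def load_non_tensor(path):
    """Inverse of save_non_tensor, dispatching on the file extension."""
    if path.suffix == ".json":
        return json.loads(path.read_text())
    return pickle.loads(path.read_bytes())


with tempfile.TemporaryDirectory() as tmp_name:
    tmp = Path(tmp_name)
    p1 = save_non_tensor("test_changed", tmp, "non_tensordata")  # JSON-friendly
    p2 = save_non_tensor({1, 2, 3}, tmp, "a_set")                # sets are not JSON-serializable
    print(p1.suffix, p2.suffix)  # .json .pkl
    assert load_non_tensor(p1) == "test_changed"
    assert load_non_tensor(p2) == {1, 2, 3}
```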

Edge cases

@tensorclass supports the equality and inequality operators, even for nested objects. Note that the non-tensor/metadata is not validated: the result is a tensorclass object holding boolean values for tensor attributes and None for non-tensor attributes.

Here is an example:

>>> print(data == data2)
MyData(
   floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
   intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
   non_tensordata=None,
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
       intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
       non_tensordata=None,
       nested=None,
       batch_size=torch.Size([3, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)
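A toy model of this comparison rule: tensor fields are compared element by element, while the metadata field is neither compared nor validated and comes back as None. `ToyData` and `toy_eq` are hypothetical, with Python lists standing in for tensors; note that the two inputs below carry different metadata and the comparison still succeeds.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyData:
    floatdata: list                 # stands in for a tensor
    non_tensordata: Optional[str]   # metadata


def toy_eq(a, b):
    """Element-wise comparison of "tensor" fields; metadata is not compared."""
    return ToyData(
        floatdata=[x == y for x, y in zip(a.floatdata, b.floatdata)],
        non_tensordata=None,  # metadata is not validated, so it comes back as None
    )


result = toy_eq(ToyData([1.0, 2.0, 3.0], "test"), ToyData([1.0, 0.0, 3.0], "other"))
print(result)  # ToyData(floatdata=[True, False, True], non_tensordata=None)
```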

@tensorclass supports setting an item. However, when setting an item, the non-tensor/metadata is checked for identity rather than equality, to avoid performance issues. Users need to make sure that the non-tensor data of the item matches that of the object to avoid discrepancies.

Here is an example.

When setting an item with different non-tensor data, a UserWarning is emitted:

>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours
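The trade-off between identity and equality can be seen with plain Python objects: `is` is a constant-time identity test, whereas `==` may have to inspect the whole value. The sketch below is a hypothetical `check_metadata` helper that mirrors the documented behaviour by warning whenever the incoming metadata is not the very same object; it is not tensordict's code, and the warning text is copied from the output above.

```python
import warnings


def check_metadata(key, current, incoming):
    """Hypothetical check: warn unless `incoming` is the very same object as `current`."""
    if incoming is not current:  # identity test: never inspects the contents
        warnings.warn(
            f"Meta data at {key!r} may or may not be equal, "
            "this may result in undefined behaviours",
            UserWarning,
        )


shared = ["test"]  # a list makes object identity unambiguous
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_metadata("non_tensordata", shared, shared)        # same object: no warning
    check_metadata("non_tensordata", shared, list(shared))  # equal but distinct: warns

print(len(caught))  # 1
```

Note that the second call warns even though the values are equal, which is exactly the "may or may not be equal" situation the warning describes.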

Even though @tensorclass supports torch functions like cat() and stack(), the non-tensor/metadata is not validated. The torch operation is performed on the tensor data, and the returned output takes its non-tensor/metadata from the first tensorclass object. Users need to make sure that all tensorclass objects in the list have the same non-tensor data to avoid discrepancies.

Here is an example:

>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.stack([data, data2], dim=0)
>>> print(stack_tc)
MyData(
    floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
    non_tensordata='test',
    nested=MyData(
        floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        non_tensordata='nested_test',
        nested=None,
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)
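The "first object wins" rule for metadata can be modelled in a few lines. In this stdlib-only sketch, `ToyData`, `toy_cat` and `same_metadata` are hypothetical names and Python lists stand in for tensors; `same_metadata` shows one possible pre-check a caller could run, assuming the metadata values are hashable.

```python
from dataclasses import dataclass


@dataclass
class ToyData:
    floatdata: list       # stands in for a tensor; first level = batch dimension
    non_tensordata: str   # metadata


def toy_cat(items):
    """Concatenate "tensor" fields; copy metadata from the first item only."""
    return ToyData(
        floatdata=[row for item in items for row in item.floatdata],
        non_tensordata=items[0].non_tensordata,  # metadata of later items is dropped
    )


def same_metadata(items):
    """Optional guard a caller could run first (assumes hashable metadata)."""
    return len({item.non_tensordata for item in items}) <= 1


batch = [ToyData([[1.0], [2.0]], "test"), ToyData([[3.0]], "test_new")]
print(same_metadata(batch))  # False: the metadata would silently diverge
print(toy_cat(batch))        # ToyData(floatdata=[[1.0], [2.0], [3.0]], non_tensordata='test')
```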

@tensorclass also supports pre-allocation: you can initialize the object with attributes set to None and set them later. Note that, at initialization, the None attributes are internally stored as non-tensor/metadata; when an attribute is reset, it is stored as either tensor data or non-tensor/metadata depending on the type of the new value.

Here is an example:

>>> from typing import Any
>>> @tensorclass
... class MyClass:
...   X: Any
...   y: Any
...
>>> data = MyClass(X=None, y=None, batch_size=[3, 4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
   X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   y='testing',
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)
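The type-based routing described above can be modelled with two dictionaries, one for tensor entries and one for metadata. In this hypothetical stdlib-only `ToyPrealloc` class, Python lists stand in for torch.Tensor; tensordict's actual storage is different, so treat this purely as an illustration of the rule.

```python
class ToyPrealloc:
    """Hypothetical model: route each attribute to a tensor store or a metadata store."""

    TENSOR_TYPES = (list,)  # stands in for torch.Tensor in this stdlib-only sketch

    def __init__(self, **kwargs):
        # Create the two stores directly, bypassing the routing __setattr__ below.
        object.__setattr__(self, "_tensors", {})
        object.__setattr__(self, "_meta", {})
        for name, value in kwargs.items():
            setattr(self, name, value)

    def __setattr__(self, name, value):
        # Drop any previous entry, then store according to the new value's type.
        self._tensors.pop(name, None)
        self._meta.pop(name, None)
        if isinstance(value, self.TENSOR_TYPES):
            self._tensors[name] = value
        else:
            self._meta[name] = value

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for routed entries.
        for store in (self._tensors, self._meta):
            if name in store:
                return store[name]
        raise AttributeError(name)


data = ToyPrealloc(X=None, y=None)   # both start out as metadata
data.X = [1.0, 1.0, 1.0]             # now a "tensor"
data.y = "testing"                   # stays metadata
print(sorted(data._tensors), sorted(data._meta))  # ['X'] ['y']
```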

tensorclass([cls, autocast, frozen, nocast, ...])

A decorator to create tensorclass classes.

TensorClass()

TensorClass is the inheritance-based version of the @tensorclass decorator.

NonTensorData(data[, _metadata, ...])

MetaData(data[, _metadata, _is_non_tensor, ...])

NonTensorStack(*args, **kwargs)

A thin wrapper around LazyStackedTensorDict to make stack on non-tensor data easily recognizable.

from_dataclass(obj, *[, dest_cls, ...])

Converts a dataclass instance or a type into a tensorclass instance or type, respectively.

Auto-casting

Warning

Auto-casting is an experimental feature and subject to change in the future. Compatibility with Python <= 3.9 is limited.

@tensorclass partially supports auto-casting as an experimental feature. Methods such as __setattr__, update, update_ and from_dict will attempt to cast type-annotated entries to the desired TensorDict / tensorclass instance (except in the cases detailed below). For instance, the following code casts the td dictionary to a TensorDict and the tc entry to a MyClass instance:

>>> from tensordict import TensorDict
>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.td, TensorDict)
>>> assert isinstance(obj.tc, MyClass)
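A rough mental model of annotation-driven casting: look up the field's type hint and, when the incoming value is a dict and the hint names a class that can be built from keyword arguments, construct that class. The stdlib-only sketch below uses dataclasses in place of tensorclass and TensorDict; `Inner`, `Outer`, `autocast_value` and `build` are hypothetical names, and the logic is a simplification, not tensordict's implementation.

```python
import dataclasses
import typing
from dataclasses import dataclass


@dataclass
class Inner:
    a: float


@dataclass
class Outer:
    value: float
    inner: Inner


def autocast_value(cls, name, value):
    """Cast a dict to the annotated dataclass type of field `name`, if there is one."""
    hint = typing.get_type_hints(cls)[name]
    if isinstance(value, dict) and dataclasses.is_dataclass(hint):
        return hint(**value)
    return value  # non-dict values, or non-dataclass hints, pass through unchanged


def build(cls, **kwargs):
    return cls(**{k: autocast_value(cls, k, v) for k, v in kwargs.items()})


obj = build(Outer, value=1.0, inner={"a": 2.0})
print(type(obj.inner).__name__)  # Inner
```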

Note

Type-annotated items that include typing.Optional or typing.Union are not compatible with auto-casting, but the other items in the tensorclass are:

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     tc_autocast: MyClass = None
...     tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     tc_autocast={"tensor": torch.randn(())},
...     tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)
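One way to see why Optional and Union get in the way: a cast needs a single target class, and a union annotation does not name one unambiguously. Such annotations can be detected with typing.get_origin, since Optional[X] and Union[X, Y] both report typing.Union as their origin. The hypothetical `is_union_annotation` helper below sketches that test; whether tensordict relies on this exact mechanism is an assumption.

```python
import types
import typing
from typing import Optional, Union


def is_union_annotation(hint):
    """True for Optional[X], Union[X, Y] and (on Python >= 3.10) X | Y annotations."""
    origin = typing.get_origin(hint)
    union_type = getattr(types, "UnionType", None)  # absent before Python 3.10
    return origin is Union or (union_type is not None and origin is union_type)


class Payload:
    pass


print(is_union_annotation(Payload))              # False: a single cast target exists
print(is_union_annotation(Optional[Payload]))    # True: casting would be skipped
print(is_union_annotation(Union[Payload, int]))  # True: casting would be skipped
```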

If at least one item in the class is annotated using the type0 | type1 syntax, auto-casting is deactivated for the whole class. Because tensorclass supports non-tensor leaves, setting a dictionary in these cases stores it as a plain dictionary instead of a tensor collection subclass (TensorDict or tensorclass):

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass | None
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)
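To check ahead of time whether a class falls into this case, it is enough to scan its annotations for the pipe syntax. The sketch below is a hypothetical helper, `uses_pipe_union`, written under the assumption that annotations are either postponed strings (as with `from __future__ import annotations`, which the examples on this page use) or evaluated `X | Y` objects; it is not the check tensordict performs internally.

```python
import types


def uses_pipe_union(cls):
    """True if any annotation on `cls` is written with the `type0 | type1` syntax."""
    union_type = getattr(types, "UnionType", None)  # only exists on Python >= 3.10
    for annotation in getattr(cls, "__annotations__", {}).values():
        # Postponed annotations (quoted, or under `from __future__ import annotations`)
        # are plain strings; this substring test is deliberately crude.
        if isinstance(annotation, str) and "|" in annotation:
            return True
        # Evaluated `X | Y` annotations are types.UnionType instances.
        if union_type is not None and isinstance(annotation, union_type):
            return True
    return False


class WithPipe:
    tensor: float
    tc: "WithPipe | None"


class WithoutPipe:
    tensor: float
    tc: "WithoutPipe"


print(uses_pipe_union(WithPipe), uses_pipe_union(WithoutPipe))  # True False
```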

Note

Auto-casting isn't enabled for leaves (tensors). The reason is that this feature isn't compatible with type annotations that use the widespread type0 | type1 syntax. Allowing auto-casting would cause very similar code to behave drastically differently when the type annotations differ only slightly.
