快捷方式

tensorclass

class tensordict.tensorclass(cls: Optional[T] = None, /, *, autocast: bool = False, frozen: bool = False, nocast: bool = False, shadow: bool = False, tensor_only: bool = False)

一个用于创建 tensorclass 类的装饰器。

tensorclass 类是专门化的 dataclasses.dataclass() 实例,它们可以开箱即用地执行一些预定义的张量操作,例如索引、项赋值、重塑、转换为设备或存储等等。

关键字参数:
  • autocast (bool, optional) – 如果为 True,则在设置参数时将强制执行类型。此参数与 autocast 互斥(两者不能同时为 True)。默认为 False

  • frozen (bool, optional) – 如果为 True,则无法修改 tensorclass 的内容。此参数提供给 dataclass 兼容性,可以通过类构造函数中的 lock 参数获得类似的行为。默认为 False

  • nocast (bool, optional) – 如果为 True,则不会将张量兼容类型(如 intnp.ndarray 等)转换为张量类型。此参数与 autocast 互斥(两者不能同时为 True)。默认为 False

  • shadow (bool, optional) – 禁用字段名与 TensorDict 的保留属性的验证。请谨慎使用,因为它可能导致意外后果。默认为 False。

  • tensor_only (bool, optional) – 如果为 True,则预计 tensorclass 中的所有项都将是张量实例(张量兼容,因为非张量数据会被尽可能地转换为张量)。这可以带来显著的速度提升,但会牺牲与非张量数据的灵活交互。默认为 False

tensorclass 可以带参数或不带参数使用

示例

>>> @tensorclass
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=False)
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=True)
... class X:
...     y: int
>>> X(torch.ones(())).y
1
>>> @tensorclass(nocast=True)
... class X:
...     y: Any
>>> X(1).y
1
>>> @tensorclass(nocast=False)
... class X:
...     y: Any
>>> X(1).y
tensor(1)

示例

>>> from tensordict import tensorclass
>>> import torch
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test"
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test"
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])
也可以将 tensorclass 实例嵌套在彼此之中

示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 尽管数据存储为 TensorDict,但类型提示有助于我们 >>> # 将数据正确转换为正确的类型 >>> assert isinstance(nesting_data.nested, type(data))

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源