快捷方式

TensorClass

class tensordict.TensorClass

TensorClass 是 @tensorclass 装饰器的基于继承的版本。

TensorClass 允许你编写比使用 @tensorclass 装饰器构建的 dataclass 更好的类型检查和更 Pythonic 的 dataclass。

示例

>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
关键字参数:
  • batch_size (torch.Size, 可选) – TensorDict 的批处理大小。默认为 None

  • device (torch.device, 可选) – 将创建 TensorDict 的设备。默认为 None

  • frozen (bool, 可选) – 如果为 True,则生成的类或实例将是不可变的。默认为 False

  • autocast (bool, 可选) – 如果为 True,则为生成的类或实例启用自动类型转换。默认为 False

  • nocast (bool, 可选) – 如果为 True,则禁用对生成的类或实例的任何类型转换。默认为 False

  • tensor_only (bool, 可选) – 如果为 True,则期望 tensorclass 中的所有项都是 tensor 实例(tensor 兼容,因为非 tensor 数据在可能的情况下会转换为 tensor)。这可以带来显著的速度提升,但会以牺牲与非 tensor 数据的灵活交互为代价。默认为 False

  • shadow (bool, 可选) – 禁用字段名与 TensorDict 的保留属性的验证。请谨慎使用,因为它可能导致意外的后果。默认为 False。

你可以通过两种方式传递布尔关键字参数(“autocast”“nocast”“frozen”“tensor_only”“shadow”):使用

方括号或关键字参数。

示例

>>> class Foo(TensorClass["autocast"]):
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass, autocast=True):  # equivalent
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass["nocast"]):
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass["nocast", "frozen"]):  # multiple keywords can be used
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass, nocast=True):  # equivalent
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass):
...     integer: int
>>> Foo(integer=1).integer
tensor(1)

警告

TensorClass 本身未被装饰为 tensorclass,但其子类将被装饰。这是因为我们无法预测 frozen 参数是否会被设置,如果被设置,它可能会与父类冲突(如果父类不是 frozen 的,子类也不能是 frozen 的)。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源