TensorClass¶
- class tensordict.TensorClass¶
TensorClass 是 @tensorclass 装饰器的基于继承的版本。
TensorClass 允许你编写比使用 @tensorclass 装饰器构建的 dataclass 更好的类型检查和更 Pythonic 的 dataclass。
示例
>>> from typing import Any >>> import torch >>> from tensordict import TensorClass >>> class Foo(TensorClass): ... tensor: torch.Tensor ... non_tensor: Any ... nested: Any = None >>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3]) >>> print(foo) Foo( non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None), tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested=None, batch_size=torch.Size([3]), device=None, is_shared=False)
- 关键字参数:
batch_size (torch.Size, 可选) – TensorDict 的批处理大小。默认为
None
。device (torch.device, 可选) – 将创建 TensorDict 的设备。默认为
None
。frozen (bool, 可选) – 如果为
True
,则生成的类或实例将是不可变的。默认为False
。autocast (bool, 可选) – 如果为
True
,则为生成的类或实例启用自动类型转换。默认为False
。nocast (bool, 可选) – 如果为
True
,则禁用对生成的类或实例的任何类型转换。默认为False
。tensor_only (bool, 可选) – 如果为
True
,则期望 tensorclass 中的所有项都是 tensor 实例(tensor 兼容,因为非 tensor 数据在可能的情况下会转换为 tensor)。这可以带来显著的速度提升,但会以牺牲与非 tensor 数据的灵活交互为代价。默认为False
。shadow (bool, 可选) – 禁用字段名与 TensorDict 的保留属性的验证。请谨慎使用,因为它可能导致意外的后果。默认为 False。
- 你可以通过两种方式传递布尔关键字参数(“autocast”、“nocast”、“frozen”、“tensor_only”、“shadow”):使用
方括号或关键字参数。
示例
>>> class Foo(TensorClass["autocast"]): ... integer: int >>> Foo(integer=torch.ones(())).integer 1 >>> class Foo(TensorClass, autocast=True): # equivalent ... integer: int >>> Foo(integer=torch.ones(())).integer 1 >>> class Foo(TensorClass["nocast"]): ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass["nocast", "frozen"]): # multiple keywords can be used ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass, nocast=True): # equivalent ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass): ... integer: int >>> Foo(integer=1).integer tensor(1)
警告
TensorClass 本身未被装饰为 tensorclass,但其子类将被装饰。这是因为我们无法预测 frozen 参数是否会被设置,如果被设置,它可能会与父类冲突(如果父类不是 frozen 的,子类也不能是 frozen 的)。