快捷方式

from_dataclass

class tensordict.from_dataclass(obj: Any, *, dest_cls: Optional[Type] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, tensor_only: bool = False, device: Optional[device] = None)

将 dataclass 实例或类型分别转换为 tensorclass 实例或类型。

此函数接收一个 dataclass 实例或 dataclass 类型,并将其转换为 tensor 兼容的类,可选地应用各种配置,例如自动批处理、不可变性和类型转换。

参数:

obj (Any) – 要转换的 dataclass 实例或类型。如果提供的是类型,则返回一个新类。

关键字参数:
  • dest_cls (tensorclass, optional) – 用于映射数据的 tensorclass 类型。如果未提供,则创建一个新类。如果 obj 是一个类型,则此参数无效。

  • auto_batch_size (bool, optional) – 如果为 True,则自动确定并应用批处理大小到结果对象。默认为 False

  • batch_dims (int, optional) – 如果 auto_batch_size 为 True,则定义输出 tensordict 的维度数。默认为 None(每个级别的完整批处理大小)。

  • batch_size (torch.Size, optional) – TensorDict 的批处理大小。默认为 None

  • frozen (bool, optional) – 如果为 True,则生成的类或实例将是不可变的。默认为 False

  • autocast (bool, optional) – 如果为 True,则为生成的类或实例启用自动类型转换。默认为 False

  • nocast (bool, optional) – 如果为 True,则禁用对生成的类或实例的任何类型转换。默认为 False

  • tensor_only (bool, optional) – 如果为 True,则期望 tensorclass 中的所有项都是 tensor 实例(tensor 兼容,因为非 tensor 数据会被尽可能转换为 tensor)。这可以带来显著的速度提升,但会牺牲与非 tensor 数据的灵活交互。默认为 False

  • inplace (bool, optional) – 如果为 True,则将修改传入的 dataclass 类型。默认为 False。如果提供的是实例,则此参数无效。

  • device (torch.device, optional) – 创建 TensorDict 的设备。默认为 None

  • shadow (bool, optional) – 禁用字段名与 TensorDict 保留属性的验证。请谨慎使用,因为它可能导致意外后果。默认为 False。

返回:

派生自提供的 dataclass 的 tensor 兼容类或实例。

抛出:

TypeError – 如果提供的输入不是 dataclass 实例或类型。

示例

>>> from dataclasses import dataclass
>>> import torch
>>> from tensordict.tensorclass import from_dataclass
>>>
>>> @dataclass
>>> class X:
...     a: int
...     b: torch.Tensor
...
>>> x = X(0, 0)
>>> x2 = from_dataclass(x)
>>> print(x2)
X(
    a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> X2 = from_dataclass(X, autocast=True)
>>> print(X2(a=0, b=0))
X(
    a=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

警告

from_dataclass() 默认将返回一个 TensorDict 实例,此方法将返回一个 tensorclass 实例或类型。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源