from_dataclass¶
- class tensordict.from_dataclass(obj: Any, *, dest_cls: Optional[Type] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, tensor_only: bool = False, device: Optional[device] = None)¶
将 dataclass 实例或类型分别转换为 tensorclass 实例或类型。
此函数接受一个 dataclass 实例或 dataclass 类型,并将其转换为 tensor 兼容的类,还可以选择性地应用各种配置,例如自动批处理、不可变性和类型转换。
- 参数:
obj (Any) – 要转换的 dataclass 实例或类型。如果提供了类型,则返回新类。
- 关键字参数:
dest_cls (tensorclass, optional) – 用于映射数据的 tensorclass 类型。如果未提供,则创建新类。如果
obj
是一个类型,则此参数无效。auto_batch_size (bool, optional) – 如果为
True
,将自动确定并应用批次大小到结果对象。默认为False
。batch_dims (int, optional) – 如果 auto_batch_size 为
True
,则定义输出 tensordict 应具有的维度数。默认为None
(每个级别完全批次大小)。batch_size (torch.Size, optional) – TensorDict 的批次大小。默认为
None
。frozen (bool, optional) – 如果为
True
,则结果类或实例将是不可变的。默认为False
。autocast (bool, optional) – 如果为
True
,则为结果类或实例启用自动类型转换。默认为False
。nocast (bool, optional) – 如果为
True
,则禁用结果类或实例的任何类型转换。默认为False
。tensor_only (bool, optional) – 如果为
True
,则预期 tensorclass 中的所有项都将是张量实例(张量兼容,因为非张量数据会被尽可能转换为张量)。这可以带来显著的速度提升,但会牺牲与非张量数据的灵活交互。默认为False
。inplace (bool, optional) – 如果为
True
,则将就地修改提供的 dataclass 类型。默认为False
。如果提供了实例,则此参数无效。device (torch.device, optional) – 将创建 TensorDict 的设备。默认为
None
。shadow (bool, optional) – 禁用字段名与 TensorDict 保留属性的验证。请谨慎使用,这可能会导致意外后果。默认为 False。
- 返回:
一个派生自提供的 dataclass 的 tensor 兼容类或实例。
- 抛出:
TypeError – 如果提供的输入不是 dataclass 实例或类型。
示例
>>> from dataclasses import dataclass >>> import torch >>> from tensordict.tensorclass import from_dataclass >>> >>> @dataclass >>> class X: ... a: int ... b: torch.Tensor ... >>> x = X(0, 0) >>> x2 = from_dataclass(x) >>> print(x2) X( a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False) >>> X2 = from_dataclass(X, autocast=True) >>> print(X2(a=0, b=0)) X( a=NonTensorData(data=0, batch_size=torch.Size([]), device=None), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False)
注意
如果提供了 dataclass 类型,则会返回一个带有指定配置的新类。如果提供了 dataclass 实例,则会返回 tensor 兼容类的新实例。auto_batch_size、frozen、autocast 和 nocast 选项允许灵活配置结果类或实例。
警告
虽然
from_dataclass()
默认返回TensorDict
实例,但此方法将返回一个 tensorclass 实例或类型。