from_dataclass¶
- class tensordict.from_dataclass(obj: Any, *, dest_cls: Optional[Type] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, tensor_only: bool = False, device: Optional[device] = None)¶
将 dataclass 实例或类型分别转换为 tensorclass 实例或类型。
此函数接收一个 dataclass 实例或 dataclass 类型,并将其转换为 tensor 兼容的类,可选地应用各种配置,例如自动批处理、不可变性和类型转换。
- 参数:
obj (Any) – 要转换的 dataclass 实例或类型。如果提供的是类型,则返回一个新类。
- 关键字参数:
dest_cls (tensorclass, optional) – 用于映射数据的 tensorclass 类型。如果未提供,则创建一个新类。如果
obj
是一个类型,则此参数无效。auto_batch_size (bool, optional) – 如果为
True
,则自动确定并应用批处理大小到结果对象。默认为False
。batch_dims (int, optional) – 如果 auto_batch_size 为
True
,则定义输出 tensordict 的维度数。默认为None
(每个级别的完整批处理大小)。batch_size (torch.Size, optional) – TensorDict 的批处理大小。默认为
None
。frozen (bool, optional) – 如果为
True
,则生成的类或实例将是不可变的。默认为False
。autocast (bool, optional) – 如果为
True
,则为生成的类或实例启用自动类型转换。默认为False
。nocast (bool, optional) – 如果为
True
,则禁用对生成的类或实例的任何类型转换。默认为False
。tensor_only (bool, optional) – 如果为
True
,则期望 tensorclass 中的所有项都是 tensor 实例(tensor 兼容,因为非 tensor 数据会被尽可能转换为 tensor)。这可以带来显著的速度提升,但会牺牲与非 tensor 数据的灵活交互。默认为False
。inplace (bool, optional) – 如果为
True
,则将修改传入的 dataclass 类型。默认为False
。如果提供的是实例,则此参数无效。device (torch.device, optional) – 创建 TensorDict 的设备。默认为
None
。shadow (bool, optional) – 禁用字段名与 TensorDict 保留属性的验证。请谨慎使用,因为它可能导致意外后果。默认为 False。
- 返回:
派生自提供的 dataclass 的 tensor 兼容类或实例。
- 抛出:
TypeError – 如果提供的输入不是 dataclass 实例或类型。
示例
>>> from dataclasses import dataclass >>> import torch >>> from tensordict.tensorclass import from_dataclass >>> >>> @dataclass >>> class X: ... a: int ... b: torch.Tensor ... >>> x = X(0, 0) >>> x2 = from_dataclass(x) >>> print(x2) X( a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False) >>> X2 = from_dataclass(X, autocast=True) >>> print(X2(a=0, b=0)) X( a=NonTensorData(data=0, batch_size=torch.Size([]), device=None), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False)
警告
而
from_dataclass()
默认将返回一个TensorDict
实例,此方法将返回一个 tensorclass 实例或类型。