tensorclass¶
- class tensordict.tensorclass(cls: Optional[T] = None, /, *, autocast: bool = False, frozen: bool = False, nocast: bool = False, shadow: bool = False, tensor_only: bool = False)¶
一个用于创建
tensorclass
类的装饰器。tensorclass
类是专门化的dataclasses.dataclass()
实例,它们可以开箱即用地执行一些预定义的张量操作,例如索引、项赋值、重塑、转换为设备或存储等等。- 关键字参数:
autocast (bool, optional) – 如果为
True
,则在设置参数时将强制执行类型。此参数与autocast
互斥(两者不能同时为 True)。默认为False
。frozen (bool, optional) – 如果为
True
,则无法修改 tensorclass 的内容。此参数提供给 dataclass 兼容性,可以通过类构造函数中的 lock 参数获得类似的行为。默认为False
。nocast (bool, optional) – 如果为
True
,则不会将张量兼容类型(如int
、np.ndarray
等)转换为张量类型。此参数与autocast
互斥(两者不能同时为 True)。默认为False
。shadow (bool, optional) – 禁用字段名与 TensorDict 的保留属性的验证。请谨慎使用,因为它可能导致意外后果。默认为 False。
tensor_only (bool, optional) – 如果为
True
,则预计 tensorclass 中的所有项都将是张量实例(张量兼容,因为非张量数据会被尽可能地转换为张量)。这可以带来显著的速度提升,但会牺牲与非张量数据的灵活交互。默认为False
。
tensorclass 可以带参数或不带参数使用
示例
>>> @tensorclass ... class X: ... y: int >>> X(torch.ones(())).y tensor(1.) >>> @tensorclass(autocast=False) ... class X: ... y: int >>> X(torch.ones(())).y tensor(1.) >>> @tensorclass(autocast=True) ... class X: ... y: int >>> X(torch.ones(())).y 1 >>> @tensorclass(nocast=True) ... class X: ... y: Any >>> X(1).y 1 >>> @tensorclass(nocast=False) ... class X: ... y: Any >>> X(1).y tensor(1)
示例
>>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass ... class MyData: ... X: torch.Tensor ... y: torch.Tensor ... z: str ... def expand_and_mask(self): ... X = self.X.unsqueeze(-1).expand_as(self.y) ... X = X[self.y] ... return X ... >>> data = MyData( ... X=torch.ones(3, 4, 1), ... y=torch.zeros(3, 4, 2, 2, dtype=torch.bool), ... z="test" ... batch_size=[3, 4]) >>> print(data) MyData( X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32), y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool), z="test" batch_size=[3, 4], device=None, is_shared=False) >>> print(data.expand_and_mask()) tensor([])
- 也可以将 tensorclass 实例嵌套在彼此之中
示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 尽管数据存储为 TensorDict,但类型提示有助于我们 >>> # 将数据正确转换为正确的类型 >>> assert isinstance(nesting_data.nested, type(data))