TorchAOBaseTensor¶
- class torchao.utils.TorchAOBaseTensor[源代码]¶
- 一个工具张量子类,提供常用函数
新的张量子类可以继承它来获得所有实用功能
- class MyTensor(TorchAOBaseTensor)
pass
- 这包括
- _get_to_kwargs 可以获取 to 的 kwargs
- 实现了:
implements = MyTensor.implements
@implements(torch.nn.functional.linear): def _(func, types, args, kwargs)
…
- register_layout:
register_layout = MyTensor.register_layout
@register_layout(PlainLayout) class PlainAQTTensorImpl(…)
…
- get_tensor_impl_constructor:
get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor # 在 MyTensor 的构造函数中: tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
- 用于简化张量子类实现的类变量
- tensor_data_names (List[str]): 所有必需的 tensor_data 的名称列表,顺序应匹配
张量子类的 __init__ 列表
- tensor_attribute_names (List[str]): 非 Tensor 属性的名称列表,
顺序应与张量子类的 __init__ 列表匹配,后跟所有 tensor_data_names 参数
optional_tensor_data_names (List[str]): 可选定义此字段以为您实现额外的样板函数,但这在有任何可选的 Tensor 数据属性时是必需的,定义时,这将是可选项的 Tensor 的名称列表 optional_tensor_attribute_names (List[str]): 可选定义此字段以为您实现额外的样板函数,但这在有任何可选的非 Tensor 属性时是必需的,定义时,这将是可选项的属性的名称列表 注意:`__init__` 和 `__new__` 中的参数顺序应与 `tensor_data_names` + `tensor_attribute_names` + `optional_tensor_data_names` (如果存在) + `optional_tensor_attribute_names` (如果存在) 完全匹配。
如果定义了 tensor_data_names 和 tensor_attribute_names,则会添加一些额外的函数,包括: __tensor_flatten__:展平子类化的张量实例,返回一个元组,第一个元素是有效张量数据的名称,
第二个元素是非 Tensor 属性的列表
__tensor_unflatten__:接受一个 `tensor_data_dict`(张量名称到张量的映射)和非张量属性列表,返回子类化张量的新实例 _apply_fn_to_data:接受一个函数(Tensor -> Tensor),将函数应用于所有张量数据并
用转换后的张量数据重新创建一个子类化张量
__repr__:子类化张量实例的字符串表示形式 _same_metadata:返回 cls 实例之间元数据是否相同 __setstate__:加载序列化的张量子类检查点时,它会将旧检查点中保存的新可选张量和张量属性设置为 None,以在将新的可选张量数据或属性添加到张量子类时保持旧检查点的向后兼容性。 PyTorch 操作:torch.Tensor.contiguous ATen 操作:aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (启用 t.to)
示例
- class MyTensor(torch.Tensor)
tensor_data_names = [“a”, “b”] tensor_attribute_names = [“c”, “d”] optional_tensor_data_names = [“e”, “f”] optional_tensor_attribute_names = [“g”, “h”]
- def __new__(
cls, a: Tensor, b: Tensor, c: int, d: str, e: Optional[Tensor] = None, f: Optional[Tensor] = None, g: Optional[int] = None, h: Optional[int] = None,
- ):
pass
- def __init__(
self, a: Tensor, b: Tensor, c: int, d: str e: Optional[Tensor] = None, f: Optional[Tensor] = None, g: Optional[int] = None, h: Optional[int] = None,
- ):
pass
- classmethod get_tensor_impl_constructor(layout_class: Callable) Callable ¶
获取 tensor_class 的 TensorImpl 类构造函数 (TensorImplClass.from_plain),基于 layout_class layout_class 表示 `Layout` 的子类类型,例如 PlainLayout
- 参数:
tensor_class – 张量子类类型
layout_class – `Layout` 的子类类型,例如 PlainLayout
- 返回:
layout_class 的 tensor impl 子类构造函数
- classmethod implements(aten_ops_or_torch_fns)¶
使用此装饰器为 `__torch_dispatch__` 中的 aten 操作(如果用户传入了一个操作列表)或 `__torch_function__` 中的 torch 函数(如果用户传入了一个单一对象)实现一个函数。
- class MyTensor(torch.Tensor)
… implements = classmethod(_implements)
implements = MyTensor.implements
@implements(torch.nn.functional.linear): def _(func, types, args, kwargs)
…