快捷方式

TorchAOBaseTensor

class torchao.utils.TorchAOBaseTensor[源代码]
一个工具张量子类,提供常用函数

新的张量子类可以继承它来获得所有实用功能

class MyTensor(TorchAOBaseTensor)

pass

这包括
_get_to_kwargs 可以获取 to 的 kwargs
class MyTensor(TorchAOBaseTensor)
def to(self, *args, **kwargs)

kwargs = _get_to_kwargs(*args, **kwargs) …

实现了:

implements = MyTensor.implements

@implements(torch.nn.functional.linear): def _(func, types, args, kwargs)

register_layout:

register_layout = MyTensor.register_layout

@register_layout(PlainLayout) class PlainAQTTensorImpl(…)

get_tensor_impl_constructor:

get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor # 在 MyTensor 的构造函数中: tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)

用于简化张量子类实现的类变量
tensor_data_names (List[str]): 所有必需的 tensor_data 的名称列表,顺序应匹配

张量子类的 __init__ 列表

tensor_attribute_names (List[str]): 非 Tensor 属性的名称列表,

顺序应与张量子类的 __init__ 列表匹配,后跟所有 tensor_data_names 参数

optional_tensor_data_names (List[str]): 可选定义此字段以为您实现额外的样板函数,但这在有任何可选的 Tensor 数据属性时是必需的,定义时,这将是可选项的 Tensor 的名称列表 optional_tensor_attribute_names (List[str]): 可选定义此字段以为您实现额外的样板函数,但这在有任何可选的非 Tensor 属性时是必需的,定义时,这将是可选项的属性的名称列表 注意:`__init__` 和 `__new__` 中的参数顺序应与 `tensor_data_names` + `tensor_attribute_names` + `optional_tensor_data_names` (如果存在) + `optional_tensor_attribute_names` (如果存在) 完全匹配。

如果定义了 tensor_data_namestensor_attribute_names,则会添加一些额外的函数,包括: __tensor_flatten__:展平子类化的张量实例,返回一个元组,第一个元素是有效张量数据的名称,

第二个元素是非 Tensor 属性的列表

__tensor_unflatten__:接受一个 `tensor_data_dict`(张量名称到张量的映射)和非张量属性列表,返回子类化张量的新实例 _apply_fn_to_data:接受一个函数(Tensor -> Tensor),将函数应用于所有张量数据并

用转换后的张量数据重新创建一个子类化张量

__repr__:子类化张量实例的字符串表示形式 _same_metadata:返回 cls 实例之间元数据是否相同 __setstate__:加载序列化的张量子类检查点时,它会将旧检查点中保存的新可选张量和张量属性设置为 None,以在将新的可选张量数据或属性添加到张量子类时保持旧检查点的向后兼容性。 PyTorch 操作:torch.Tensor.contiguous ATen 操作:aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (启用 t.to)

示例

class MyTensor(torch.Tensor)

tensor_data_names = [“a”, “b”] tensor_attribute_names = [“c”, “d”] optional_tensor_data_names = [“e”, “f”] optional_tensor_attribute_names = [“g”, “h”]

def __new__(

cls, a: Tensor, b: Tensor, c: int, d: str, e: Optional[Tensor] = None, f: Optional[Tensor] = None, g: Optional[int] = None, h: Optional[int] = None,

):

pass

def __init__(

self, a: Tensor, b: Tensor, c: int, d: str e: Optional[Tensor] = None, f: Optional[Tensor] = None, g: Optional[int] = None, h: Optional[int] = None,

):

pass

classmethod get_tensor_impl_constructor(layout_class: Callable) Callable

获取 tensor_class 的 TensorImpl 类构造函数 (TensorImplClass.from_plain),基于 layout_class layout_class 表示 `Layout` 的子类类型,例如 PlainLayout

参数:
  • tensor_class – 张量子类类型

  • layout_class – `Layout` 的子类类型,例如 PlainLayout

返回:

layout_class 的 tensor impl 子类构造函数

classmethod implements(aten_ops_or_torch_fns)

使用此装饰器为 `__torch_dispatch__` 中的 aten 操作(如果用户传入了一个操作列表)或 `__torch_function__` 中的 torch 函数(如果用户传入了一个单一对象)实现一个函数。

class MyTensor(torch.Tensor)

… implements = classmethod(_implements)

implements = MyTensor.implements

@implements(torch.nn.functional.linear): def _(func, types, args, kwargs)

classmethod register_layout(layout_class: Callable)

布局注册的辅助函数,用于实现每个张量子类的 `register_layout` 装饰器,请参阅 aqt.py 中的示例用法

参数:
  • tensor_class – 张量子类类型

  • layout_class – `Layout` 的子类类型,例如 PlainLayout

返回:

一个在表中注册 tensor impl 构造函数的装饰器

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源