编写自己的量化张量¶
torchao 中的量化建立在张量子类的基础上。它们是 torchao 的主要扩展点,用于使用低精度计算提供灵活的推理和训练支持,同时与 torch.compile、autograd 和分布式原语等重要 PyTorch 功能进行组合。
在本教程中,我们将重点介绍与模块交换相比,利用张量子类的好处,并逐步介绍如何使用此方法表达量化的简单示例。
什么是张量子类?¶
张量子类是简单地继承自 torch.Tensor 的类。它们允许用户在模型中现有操作之间插入自定义计算逻辑,从而使顶级 torch 命名空间中的函数(如 torch.add)能够继续无缝工作。
张量子类方法的明显替代方案是模块交换:例如,只需将模型中的所有 nn.Linear 模块替换为自定义的 Int8QuantizedLinear 模块。与此方法相比,使用张量子类有几个重要好处
更细粒度的集成点。 模块交换在模块级别拦截计算,因此不适用于依赖 torch 函数或原生模块变体(例如,nn.Linear 的略微修改版本)的模型。相比之下,由于张量子类在函数/操作级别拦截计算,因此只要使用相同的函数/操作,我们就可以对模型进行量化。
更好的可组合性。 使用模块交换组合多个功能会很笨拙。例如,组合两个现有的 Int8QuantizedLinear 和 DistributedLinear 模块将要求用户创建另一个复制这些功能的线性类。张量子类通过简单地将一个子类包装在另一个子类中来解决此问题。如果外部张量(例如 DTensor)知道内部张量已量化,这也可以提供性能优势,因此可以使用更少的网络和内存带宽执行昂贵的 allgather 操作。
重用 PyTorch 组件。 使用张量子类表达量化是很自然的,因为量化张量只是具有不同 dtype 的 torch.Tensor。模型结构不变(nn.Linears 仍然是 nn.Linears),因此后续优化过程也可以保持与以前完全相同。
在本教程的其余部分,我们将通过一个示例来介绍如何使用这两种方法表达量化。有关张量子类的更多阅读,请参阅
使用模块交换进行量化¶
我们首先举一个简单的例子,说明如何使用模块交换实现 int8 对称仅权重量化。所有代码都可以在此示例脚本中找到。我们将使用以下函数将 float32 张量量化为 int8 张量
from typing import Tuple
import torch
def int8_symmetric_quantize(
fp32_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Symmetrically quantize the torch.float32 tensor into torch.int8.
Return a 2-tuple of (quantized value, scale).
input: dimensions=[M, N], dtype=torch.float32
output: dimensions=[M, N], dtype=torch.int8
scale: dimensions=[M, 1], dtype=torch.float32
"""
quant_min = -128
quant_max = 127
min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
scale = scale.view(fp32_tensor.shape[0], -1)
out = torch.round(fp32_tensor * (1.0 / scale))
out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
return out, scale
接下来,我们将创建一个新的 QuantizedLinear 模块,该模块调用此函数以动态量化权重
class QuantizedLinear(torch.nn.Linear):
"""
Linear module that performs dynamic and symmetric weight-only
int8 quantization.
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
w_int8, scale = int8_symmetric_quantize(self.weight)
return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t()
@classmethod
def from_float(cls, mod: torch.nn.Linear):
new_linear = cls(mod.in_features, mod.out_features, mod.bias)
new_linear.weight = mod.weight
return new_linear
然后,剩下的唯一事情就是将模型中的所有 nn.Linear 模块替换为新的 QuantizedLinear 模块。让我们使用以下玩具模型进行演示
import copy
class ToyModel(torch.nn.Module):
def __init__(self, m: int, n: int, k: int):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
float_model = ToyModel(64, 128, 32).cuda()
quantized_model = copy.deepcopy(float_model)
# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model.named_children():
if type(child) == torch.nn.Linear:
new_linear = QuantizedLinear.from_float(child)
setattr(quantized_model, name, new_linear)
验证模型现在是否使用我们的 QuantizedLinear 模块。此模型现在可以使用了!
>>> print(float_model)
ToyModel(
(linear1): Linear(in_features=64, out_features=128, bias=False)
(linear2): Linear(in_features=128, out_features=32, bias=False)
)
>>> print(quantized_model)
ToyModel(
(linear1): QuantizedLinear(in_features=64, out_features=128, bias=False)
(linear2): QuantizedLinear(in_features=128, out_features=32, bias=False)
)
这种简单方法的一个重要缺点是灵活性。目前这仅适用于原生 PyTorch 模块,但如果模型有略微修改的线性模块(例如,支持分布式训练)怎么办?它也不适用于直接调用线性函数版本 (torch.nn.functional.linear) 的模型。
此外,假设我们想将此功能与通过模块交换实现的分布式功能组合。除了创建另一个结合了这两种功能的模块之外,没有干净的方法可以做到这一点。这些限制可以通过张量子类解决,这是在模型中插入自定义计算(如量化)的更优雅方式。
使用张量子类进行量化¶
在这里,我们将使用基于 __torch_dispatch__ 的张量子类重新实现上述量化技术。
张量子类(通常使用 __torch_dispatch__)是 PyTorch 中一个非常强大/灵活的扩展点。它们作为扩展点主要有两个目的
张量子类允许您覆盖(几乎)每个 PyTorch API 的实现,并且在实现其他 PyTorch 产品时大量使用
张量子类允许您将张量数据与附加元数据耦合。一些示例
[量化] 比例/零点元数据(AffineQuantizedTensor)
[不规则性] 不规则结构元数据(NestedTensor,文档)
其他一些对张量子类感兴趣的资源
__torch_dispatch__ 文档(链接)
__torch_dispatch__ 是什么(以及为什么)(链接)
使用 __torch_dispatch__ 实现 FlopCounter 和 MemoryTracker 的 Google Colab(链接)
话不多说,让我们从为对称量化定义骨干张量子类开始
class Int8SymmetricTensor(torch.Tensor):
"""
Our subclass represents a tensor that has been quantized to int8
It will hold two inner tensors:
int_data: int8[M, N]
scale: fp32[M, 1]
"""
@staticmethod
@torch._dynamo.disable
def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor):
return torch.Tensor._make_wrapper_subclass(
cls,
int_data.shape,
strides=int_data.stride(),
storage_offset=int_data.storage_offset(),
dtype=scale.dtype,
device=int_data.device,
)
@torch._dynamo.disable
def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
# inner data expected to be quantized already
assert int_data.dtype is torch.int8
# we could do more work to support ndim > 2!
assert int_data.ndim == 2
assert scale.ndim == 2
self.int_data = int_data
self.scale = scale
def __tensor_flatten__(self) -> Tuple[List[str], Any]:
"""
Returns a tuple of:
names of all inner tensor attributes (two in our case)
any other additional, non-tensor metadata.
Needed for PT2 support.
"""
return ["int_data", "scale"], None
@classmethod
def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None):
"""
__tensor_unflatten__ should effectively undo __tensor_flatten__.
inputs:
a dict mapping names of inner tensor attributes back to the tensors
the constant metadata from __tensor_flatten__
output:
a new instance of your subclass
Needed for PT2 support.
"""
assert extra_metadata is None
int_data = tensor_data_dict["int_data"]
scale = tensor_data_dict["scale"]
return Int8SymmetricTensor(int_data, scale)
def __repr__(self):
return f'Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})'
@staticmethod
def from_float(float_tensor):
"""
Actually performs the symmetric quantization.
In our simple inference example we will quantize weights "ahead-of-time",
although later in a training example we can quantize/dequantize
during model execution, inside of our __torch_dispatch__
input:
float32 torch.Tensor
output:
Int8SymmetricTensor
"""
int8_tensor, scale = int8_symmetric_quantize(float_tensor)
return Int8SymmetricTensor(int8_tensor, scale)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
"""
Called for each ATen operator that our subclass is passed as an input to.
We need to define our own implementation for every operator here.
"""
if kwargs is None:
kwargs = {}
if func not in op_implementations_dict:
raise AssertionError(f'Int8SymmetricTensor does not yet support op: {str(func)}')
return op_implementations_dict[func](func, *args, **kwargs)
# Convenience function for registering our own implementation
# to every ATen operator in PyTorch
op_implementations_dict = {}
def register_op(ops: List[torch._ops.OpOverload]):
def impl_decorator(op_impl):
global op_implementations_dict
for op in ops:
op_implementations_dict[op] = op_impl
return op_impl
return impl_decorator
在上面的代码中,我们做了几件事
定义了一个基本的“包装器”张量子类——它实际上是一个容器对象,包含一些内部数据(特别是对应于我们的 int8 数据和比例的两个张量)
定义了一个 __torch_dispatch__ 实现,当模型对我们的任何子类输入调用任何 ATen 运算符时,都会调用它
(对于 PT2 支持)定义了 __tensor_flatten__/__tensor_unflatten__ 方法。这是我们子类要与 torch.compile 一起工作所需的最大要求之一(稍后会详细介绍)。它有效地告诉 torch.compile 如何将我们的子类“解糖”为它的内部组件。
(对于 PT2 支持)为两个构造函数方法(__new__ 和 __init__)添加了 torch._dynamo.disable 装饰器(稍后会详细介绍)。
我们应该实现哪些运算符?¶
PyTorch 有一个相当大的运算符接口。我们不应该尝试让新的张量子类达到 100% 的覆盖率,而应该专注于我们上面玩具模型所需的运算符。
那么,我们的模型中调用了哪些运算符呢,这样我们才能知道首先要实现什么?暴力方法是重复运行模型,查看子类中出现哪些运算符错误。更优雅的方法是记录模型在执行期间看到的所有运算符。这可以通过另一个 LoggingTensor 子类来实现,如此示例所示。
让我们实现下面必要的运算符
from torch.utils._python_dispatch import return_and_correct_aliasing
@register_op([torch.ops.aten.mm.default])
def int8_mm(func, x, weight):
assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!"
return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale
@register_op([
torch.ops.aten.detach.default,
torch.ops.aten.t.default,
])
def int8_view_ops(func, *args, **kwargs):
assert isinstance(args[0], Int8SymmetricTensor)
out_data = func(args[0].int_data, *args[1:], **kwargs)
out_scale = func(args[0].scale, *args[1:], **kwargs)
out = Int8SymmetricTensor(out_data, out_scale)
return return_and_correct_aliasing(func, args, kwargs, out)
你会很快注意到一件事:我们的模型本身由几个线性层组成,但我们看到一些操作,如 aten.t 和 aten.mm 击中了我们的子类。一些背景
我们有许多 C++ 中存在的运算符分解,它们运行在张量子类“之上”。linear 就是其中之一(分解位于此处)
分解在某种意义上是好的,因为它们缩小了作为子类作者必须实现的 API 大小。但如果你宁愿覆盖“更高级别”的运算符而不是其分解中的底层操作,它们可能会很痛苦。
如果你更喜欢在更高级别覆盖某些操作(如 Linear),你可以使用 __torch_function__ (示例) 来做到这一点。值得注意的是,如果你想要自动梯度支持,那么你在 __torch_function__ 层执行的任何覆盖都需要以可微分的方式编写,而你在 __torch_dispatch__ 中执行的任何覆盖将自动可微分。
我们的实现中有一些值得指出的细微差别
你会注意到我们不再需要在 mm 实现中转置权重/比例。那是因为转置在到达 aten.mm 操作之前“已经发生”了。
我们的 aten.mm 实现不返回张量子类输出。从这个意义上说,我们量化子类的“传播”以矩阵乘法结束。这映射到我们的权重是低精度的,但我们需要以高精度执行矩阵乘法本身的事实。通常,子类作者可以自由选择他们的子类对哪些操作进行传播或不传播。如果你希望模型中的每个函数都被量化(包括所有逐点和缩减操作),你可以编写子类实现来量化每个操作的输出并始终返回一个子类。
我们能够对 4 个视图操作重用相同的实现。通常,许多操作可能使用一个非常通用的实现:解包任何子类输入,对内部张量运行底层操作符,并将输出包装回子类。
然而,你是否总是可以重用一个实现,取决于你正在尝试做什么。例如,我们在子类上实现了 transpose(dim0, dim1),通过对内部数据和内部比例张量调用相同的转置。如果我们的比例和数据张量具有不同的维度数量,这将不起作用,因此在这种情况下,转置将需要自定义实现。
比较输出¶
完成所有这些之后,让我们用两种量化版本运行模型,并确认它们给出相同的输出!
float_model = ToyModel(64, 128, 32).cuda()
quantized_model_module_swap = copy.deepcopy(float_model)
quantized_model_subclass = copy.deepcopy(float_model)
# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model_module_swap.named_children():
if type(child) == torch.nn.Linear:
new_linear = QuantizedLinear.from_float(child)
setattr(quantized_model_module_swap, name, new_linear)
# Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses
for name, child in quantized_model_subclass.named_children():
if type(child) == torch.nn.Linear:
subclass_param = Int8SymmetricTensor.from_float(child.weight)
child.weight = torch.nn.Parameter(subclass_param, requires_grad=True)
with torch.no_grad():
x = torch.randn(64, 64, 64, device='cuda')
out_module_swap = quantized_model_module_swap(x)
out = quantized_model_subclass(x)
print(torch.allclose(out, out_module_swap)) # prints True
# We can also use torch.compile to fuse some of our quantized logic
out_compiled = torch.compile(quantized_model_subclass)(x)
print(torch.allclose(out, out_compiled)) # prints True
下一步¶
在本教程中,我们演示了如何构建一个简单的量化张量子类。这是本系列两个教程的第一部分。下一篇文章将讨论如何向张量子类添加更高级的功能,例如使其可训练、与 DTensors 组合以及添加张量并行支持。有关 torchao 中 AffineQuantizedTensor 如何使用张量子类构建的更详细示例,另请参阅此示例。
如果您在实现子类时有任何疑问,请随时在此处提交问题。