• 文档 >
  • 编写自己的量化张量
快捷方式

编写您自己的量化张量

torchao 中的量化建立在张量子类(tensor subclasses)的基础上。它们是 torchao 提供灵活推理和训练支持的主要扩展点,通过低精度计算,同时与 torch.compile、autograd 和分布式原语等重要的 PyTorch 特性相结合。

在本教程中,我们将重点介绍与模块替换(module swaps)相比,利用张量子类的好处,并通过一个简单的示例来演示如何使用这种方法来表达量化。

什么是张量子类?

张量子类只是继承自 torch.Tensor 的类。它们允许用户在模型中现有的操作之间插入自定义计算逻辑,这样像 torch.add 这样的顶级 torch 命名空间中的函数将继续无缝工作。

与张量子类方法显而易见的替代方法是模块替换:例如,只需将模型中的所有 nn.Linear 模块替换为您自定义的 Int8QuantizedLinear 模块。与此方法相比,使用张量子类有几个重要的好处:

  1. 更精细的集成点。 模块替换在模块级别拦截计算,因此对于依赖 torch 函数或原生模块变体的模型(例如,nn.Linear 的稍作修改的版本)无效。相比之下,由于张量子类在函数/操作级别拦截计算,只要使用相同的函数/操作,我们就可以量化模型。

  2. 更好的可组合性。 使用模块替换组合多个功能很麻烦。例如,组合两个现有的 Int8QuantizedLinear 和 DistributedLinear 模块需要用户创建一个另一个线性类,该类复制这些功能。张量子类通过简单地将一个子类包装在另一个子类中来绕过此问题。如果外部张量(例如 DTensor)意识到内部张量已被量化,这也可以提供性能优势,从而可以使用更少的网络和内存带宽执行昂贵的 allgather 操作。

  3. 重用 PyTorch 组件。 使用张量子类来表达量化是很自然的,因为量化张量只是具有不同 dtype 的 torch.Tensors。模型结构保持不变(nn.Linears 仍然是 nn.Linears),因此后续的优化传递也可以与之前完全相同。


在教程的其余部分,我们将通过一个示例来演示如何使用这两种方法来实现量化。有关张量子类的更多阅读,请参考:

通过模块替换进行量化

我们首先通过一个简单的示例来实现 int8 仅权重量化,方法是使用模块替换。所有代码都可以在这个 示例脚本 中找到。我们将使用以下函数将 float32 张量量化为 int8 张量:

from typing import Tuple
import torch

def int8_symmetric_quantize(
    fp32_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Symmetrically quantize the torch.float32 tensor into torch.int8.
    Return a 2-tuple of (quantized value, scale).

    input: dimensions=[M, N], dtype=torch.float32
    output: dimensions=[M, N], dtype=torch.int8
    scale: dimensions=[M, 1], dtype=torch.float32
    """
    quant_min = -128
    quant_max = 127
    min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
    max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    scale = scale.view(fp32_tensor.shape[0], -1)
    out = torch.round(fp32_tensor * (1.0 / scale))
    out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
    return out, scale

接下来,我们将创建一个新的 QuantizedLinear 模块,它调用此函数来动态量化权重:

class QuantizedLinear(torch.nn.Linear):
    """
    Linear module that performs dynamic and symmetric weight-only
    int8 quantization.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w_int8, scale = int8_symmetric_quantize(self.weight)
        return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t()

    @classmethod
    def from_float(cls, mod: torch.nn.Linear):
        new_linear = cls(mod.in_features, mod.out_features, mod.bias)
        new_linear.weight = mod.weight
        return new_linear

然后,唯一剩下的就是将模型中的所有 nn.Linear 模块替换为我们的新 QuantizedLinear。让我们使用以下玩具模型进行演示:

import copy

class ToyModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

float_model = ToyModel(64, 128, 32).cuda()
quantized_model = copy.deepcopy(float_model)

# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model.named_children():
    if type(child) == torch.nn.Linear:
        new_linear = QuantizedLinear.from_float(child)
        setattr(quantized_model, name, new_linear)

验证模型现在是否使用了我们的 QuantizedLinear 模块。该模型现在已准备就绪!

>>> print(float_model)
ToyModel(
  (linear1): Linear(in_features=64, out_features=128, bias=False)
  (linear2): Linear(in_features=128, out_features=32, bias=False)
)

>>> print(quantized_model)
ToyModel(
  (linear1): QuantizedLinear(in_features=64, out_features=128, bias=False)
  (linear2): QuantizedLinear(in_features=128, out_features=32, bias=False)
)

这种简单方法的一个重要缺点是灵活性。目前这仅适用于原生 PyTorch 模块,但如果模型具有稍作修改的线性模块,例如支持分布式训练,该怎么办?如果模型直接调用线性(torch.nn.functional.linear)的功能版本,它也将无效。

此外,假设我们想将此功能与分布相结合,分布也通过模块替换实现。除了创建另一个组合了这两个功能的模块之外,没有干净的方法可以做到这一点。这些限制可以通过张量子类来解决,张量子类是拦截模型中自定义计算(如量化)的一种更优雅的方式。

通过张量子类进行量化

在这里,我们将使用一个基于 __torch_dispatch__ 的张量子类来重新实现上述量化技术。

张量子类(通常利用 __torch_dispatch__)是 PyTorch 中一个非常强大/灵活的扩展点。它们作为扩展点有两个主要目的:

  1. 张量子类允许您覆盖(几乎)每个 PyTorch API 的**实现**,并且在很大程度上用于实现其他 PyTorch 产品。

  2. 张量子类允许您将张量数据与附加**元数据进行耦合**。一些示例:

    1. [分布式] 关于张量如何在各个节点之间分片(DTensor文档)的元数据

    2. [量化] 尺度/零点元数据(AffineQuantizedTensor

    3. [不规则性] 关于不规则结构(NestedTensor文档)的元数据

一些关于张量子类的其他资源,供感兴趣的读者参考:

  1. __torch_dispatch__ 文档(链接

  2. 什么是 __torch_dispatch__(以及为什么使用它)链接

  3. 使用 __torch_dispatch__ 实现 FlopCounter 和 MemoryTracker 的 Google Colab(链接

话不多说,让我们开始定义我们最基本的对称量化张量子类:

class Int8SymmetricTensor(torch.Tensor):
    """
    Our subclass represents a tensor that has been quantized to int8
    It will hold two inner tensors:
      int_data: int8[M, N]
      scale: fp32[M, 1]
    """

    @staticmethod
    @torch._dynamo.disable
    def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            int_data.shape,
            strides=int_data.stride(),
            storage_offset=int_data.storage_offset(),
            dtype=scale.dtype,
            device=int_data.device,
        )

    @torch._dynamo.disable
    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
        # inner data expected to be quantized already
        assert int_data.dtype is torch.int8
        # we could do more work to support ndim > 2!
        assert int_data.ndim == 2
        assert scale.ndim == 2
        self.int_data = int_data
        self.scale = scale

    def __tensor_flatten__(self) -> Tuple[List[str], Any]:
        """
        Returns a tuple of:
          names of all inner tensor attributes (two in our case)
          any other additional, non-tensor metadata.

        Needed for PT2 support.
        """
        return ["int_data", "scale"], None

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None):
        """
         __tensor_unflatten__ should effectively undo __tensor_flatten__.

        inputs:
          a dict mapping names of inner tensor attributes back to the tensors
          the constant metadata from __tensor_flatten__
        output:
          a new instance of your subclass

        Needed for PT2 support.
        """
        assert extra_metadata is None
        int_data = tensor_data_dict["int_data"]
        scale = tensor_data_dict["scale"]
        return Int8SymmetricTensor(int_data, scale)

    def __repr__(self):
        return f'Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})'

    @staticmethod
    def from_float(float_tensor):
        """
        Actually performs the symmetric quantization.
        In our simple inference example we will quantize weights "ahead-of-time",
        although later in a training example we can quantize/dequantize
        during model execution, inside of our __torch_dispatch__

        input:
          float32 torch.Tensor
        output:
          Int8SymmetricTensor
        """
        int8_tensor, scale = int8_symmetric_quantize(float_tensor)
        return Int8SymmetricTensor(int8_tensor, scale)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        """
        Called for each ATen operator that our subclass is passed as an input to.
        We need to define our own implementation for every operator here.
        """
        if kwargs is None:
            kwargs = {}
        if func not in op_implementations_dict:
            raise AssertionError(f'Int8SymmetricTensor does not yet support op: {str(func)}')
        return op_implementations_dict[func](func, *args, **kwargs)


# Convenience function for registering our own implementation
# to every ATen operator in PyTorch
op_implementations_dict = {}
def register_op(ops: List[torch._ops.OpOverload]):
    def impl_decorator(op_impl):
        global op_implementations_dict
        for op in ops:
            op_implementations_dict[op] = op_impl
        return op_impl

    return impl_decorator

在上面的代码中,我们做了几件事:

  1. 定义了一个基本的“包装器”张量子类——它实际上是一个容器对象,保存了一些内部数据(特别是两个张量,对应于我们的 int8 数据和尺度)

  2. 定义了一个 __torch_dispatch__ 实现,对于我们模型对任何子类输入调用的每个 ATen 操作都会调用它

  3. (为了支持 PT2)定义了一个 __tensor_flatten__/__tensor_unflatten__ 方法。这是我们的子类与 torch.compile 兼容的一些要求中最重要的部分(稍后会详细介绍)。它有效地告诉 torch.compile 如何将我们的子类“解糖”成其内部组件。

  4. (为了支持 PT2)在构造方法(__new____init__)上添加了 torch._dynamo.disable 装饰器(稍后会详细介绍)。

应该实现哪些操作?

PyTorch 拥有相当大的操作表面。与其试图让我们的新张量子类实现 100% 的覆盖,不如让我们专注于玩具模型所需的那些操作。

但是,我们的模型调用了哪些操作,这样我们才知道首先要实现什么?暴力方法是反复运行模型,查看子类中出现的错误操作。更优雅的方法是记录模型在执行过程中看到的每个操作。这可以通过另一个 LoggingTensor 子类来实现,如此示例所示。

让我们在下面实现必要的操作:

from torch.utils._python_dispatch import return_and_correct_aliasing

@register_op([torch.ops.aten.mm.default])
def int8_mm(func, x, weight):
    assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!"
    return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale

@register_op([
    torch.ops.aten.detach.default,
    torch.ops.aten.t.default,
])
def int8_view_ops(func, *args, **kwargs):
    assert isinstance(args[0], Int8SymmetricTensor)
    out_data = func(args[0].int_data, *args[1:], **kwargs)
    out_scale = func(args[0].scale, *args[1:], **kwargs)
    out = Int8SymmetricTensor(out_data, out_scale)
    return return_and_correct_aliasing(func, args, kwargs, out)

您会很快注意到一件事:我们的模型本身由几个线性层组成,但我们看到像 aten.taten.mm 这样的操作击中了我们的子类。一些背景信息:

  • 我们在 C++ 中有许多操作分解,它们运行在张量子类“之上”。linear 就是这样一个操作(分解位于 此处)。

  • 分解可能是好的,因为它们缩小了您作为子类作者需要实现的大 API 表面积。但如果宁愿覆盖“更高层”的操作而不是其分解中的底层操作,它们可能会很麻烦。

  • 如果您宁愿在更高层覆盖某些操作(例如 Linear),您可以使用 __torch_function__示例)来实现。值得注意的是,如果您想要自动微分支持,那么您在 __torch_function__ 层执行的任何覆盖都需要以可微分的方式编写,而您在 __torch_dispatch__ 中执行的任何覆盖都将自动可微分。

我们的实现中有一些细微之处值得指出:

  1. 您会注意到,我们在 mm 实现中不再需要转置权重/尺度。这是因为在 aten.mm 操作发生之前,转置“已经”完成了。

  2. 我们的 aten.mm 实现**不**返回张量子类输出。从这个意义上说,我们的量化子类的“传播”在矩阵乘法处结束。这对应于我们的权重是低精度的,但我们需要执行高精度的矩阵乘法的事实。总的来说,子类作者可以自由选择他们的子类会传播或不传播哪些操作。如果您希望模型中的每个函数都被量化(包括所有逐点和归约操作),您可以编写您的子类实现来量化每个操作的输出,并始终返回一个子类。

  3. 我们能够为 4 个视图操作重用相同的实现。总的来说,许多操作可能适用于相当通用的实现:解开任何子类输入,在内部张量上运行底层操作,并将输出重新包装到子类中。

    • 然而,是否总是可以重用实现取决于您想做什么。例如,我们通过对内部数据和内部尺度张量执行相同的转置来实现我们子类的 transpose(dim0, dim1)。如果我们的尺度和数据张量具有不同数量的维度,这种情况就不起作用了,因此在这种情况下,转置需要自定义实现。

比较输出

在完成所有这些之后,让我们用两种量化版本运行我们的模型,并确认它们给出相同的输出!

float_model = ToyModel(64, 128, 32).cuda()
quantized_model_module_swap = copy.deepcopy(float_model)
quantized_model_subclass = copy.deepcopy(float_model)

# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model_module_swap.named_children():
    if type(child) == torch.nn.Linear:
        new_linear = QuantizedLinear.from_float(child)
        setattr(quantized_model_module_swap, name, new_linear)

# Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses
for name, child in quantized_model_subclass.named_children():
    if type(child) == torch.nn.Linear:
        subclass_param = Int8SymmetricTensor.from_float(child.weight)
        child.weight = torch.nn.Parameter(subclass_param, requires_grad=True)

with torch.no_grad():
    x = torch.randn(64, 64, 64, device='cuda')
    out_module_swap = quantized_model_module_swap(x)
    out = quantized_model_subclass(x)
    print(torch.allclose(out, out_module_swap))  # prints True

    # We can also use torch.compile to fuse some of our quantized logic
    out_compiled = torch.compile(quantized_model_subclass)(x)
    print(torch.allclose(out, out_compiled))  # prints True

下一步

在本教程中,我们演示了如何构建一个简单的量化张量子类。这是本系列教程的第一部分。 下一篇文章将讨论如何为您的张量子类添加更高级的功能,例如使其可训练、与 DTensors 组合以及添加张量并行支持。有关 torchao 中 AffineQuantizedTensor 如何使用张量子类构建的更详细示例,请参阅 此示例

如果您在实现子类时有任何疑问,请随时在此处 提出问题

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源