autoquant¶
- torchao.quantization.autoquant(model, example_input=None, qtensor_class_list=[<class 'torchao.quantization.autoquant.AQDefaultLinearWeight'>, <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>], filter_fn=None, mode=['interpolate', 0.85], manual=False, set_inductor_config=True, supress_autoquant_errors=True, min_sqnr=None, **aq_kwargs)[源]¶
自动量化是一个过程,它在一组潜在的量化张量子类中,识别出对模型的每一层进行量化的最快方法。
自动量化分为三个步骤
1-准备模型:在模型中搜索 Linear 层,将其权重替换为 AutoQuantizableLinearWeight。
2-形状校准:用户在一个或多个输入上运行模型,记录 AutoQuantizableLinearWeight 所看到的激活形状/dtype 的详细信息,以便我们知道在步骤 3 中优化量化操作时要使用的形状/dtype。
- 3-完成自动量化:对于每个 AutoQuantizableLinearWeight,针对 qtensor_class_list 中的每个成员,在每种形状/dtype 上运行基准测试。
选择最快的选项,从而得到一个高性能的模型。
此 autoquant 函数执行步骤 1。步骤 2 和 3 可以通过简单地运行模型来完成。如果提供了 example_input,此函数也会运行模型(完成步骤 2 和 3)。此 autoquant API 可以处理已经应用了 torch.compile 的模型,在这种情况下,一旦模型运行并量化,torch.compile 过程通常也会继续进行。
为了优化输入形状/dtype 的组合,用户可以将 manual=True,使用所有所需的形状/dtype 运行模型,然后在所需的输入集已记录后调用 model.finalize_autoquant 来完成量化。
- 参数:
model (torch.nn.Module) – 要自动量化的模型。
example_input (任意, 可选) – 模型的示例输入。如果提供,函数将对此输入执行前向传播(除非 manual=True,否则将完全自动量化模型)。默认为 None。
qtensor_class_list (列表, 可选) – 用于量化的张量类列表。默认为 DEFAULT_AUTOQUANT_CLASS_LIST。
filter_fn (可调用对象, 可选) – 应用于模型参数的过滤函数。默认为 None。
mode (列表, 可选) – 包含量化模式设置的列表。第一个元素是模式类型(例如,“interpolate”),第二个元素是模式值(例如,0.85)。默认为 [“interpolate”, .85]。
manual (布尔值, 可选) – 是否在单次运行后停止形状校准并进行自动量化(默认,False),或者等待用户调用 model.finalize_autoquant (True) 以便记录具有多个形状/dtype 的输入。
set_inductor_config (布尔值, 可选) – 是否自动使用推荐的 Inductor 配置设置(默认为 True)。
supress_autoquant_errors (布尔值, 可选) – 是否在自动量化期间抑制错误。(默认为 True)。
min_sqnr (浮点数, 可选) – 量化层输出与非量化层输出之间的最小可接受信噪比(https://en.wikipedia.org/wiki/Signal-to-quantization-noise_ratio),用于过滤。
impact (导致过大数值的量化方法) –
合理 (用户可以从一个合理的数字开始) –
结果 (比如 40,并根据结果进行调整) –
**aq_kwargs – 自动量化过程的额外关键字参数。
- 返回:
- 自动量化并封装的模型。如果提供了 example_input,函数将对输入执行前向传播。
并返回前向传播的结果。
- 返回类型: