快捷方式

Float8ActInt4WeightQATQuantizer

class torchao.quantization.qat.Float8ActInt4WeightQATQuantizer(group_size: Optional[int] = 64, scale_precision: dtype = torch.bfloat16)[源代码]

QAT 量化器,用于将模型中的线性层应用动态逐行 float8 激活 + 每组/每通道 int4 对称权重伪量化。目前仅支持 float8 激活的逐行粒度。

参数:
  • group_size (Optional[int]) – 权重的每个量化组中的元素数量,默认为 64。对于每通道使用 None。

  • scale_precision – 权重缩放的精度,默认为 torch.bfloat16。

prepare(model: Module, *args: Any, **kwargs: Any) Module[源代码]

将所有 nn.Linear 替换为 FakeQuantizedLinear,其中激活使用 float8 伪量化器,权重使用 int4 伪量化器。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源