• 文档 >
  • (第 2 部分)使用 QAT、QLoRA 和 float8 进行微调
快捷方式

(第 2 部分) 使用 QAT、QLoRA 和 float8 进行微调

TorchAO 通过利用我们集成到合作框架中的量化和稀疏技术,提供端到端的预训练、微调和模型服务优化流程。这是展示此端到端流程的 3 个教程中的第 2 部分,重点介绍微调步骤。

_images/e2e_flow_part2.png

微调是使预训练模型适应更多领域特定数据的重要步骤。在本教程中,我们将演示可在微调期间应用于模型的 3 种模型优化技术

1. **量化感知训练 (QAT)**,用于在微调期间使模型适应量化数值,目标是在模型最终量化(例如在服务步骤中)时,减轻微调模型中的量化降级。请查看[我们的博客](https://pytorch.ac.cn/blog/quantization-aware-training/)和[自述文件](https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md)了解更多详细信息!

2. **量化低秩适应 (QLoRA)**,通过引入小的、可训练的低秩矩阵并冻结原始预训练检查点来降低微调的资源需求,这是一种参数高效微调 (PEFT)。请参阅[原始论文](https://arxiv.org/pdf/2305.14314)了解更多详细信息。

3. **Float8 量化微调**,通过将高精度权重和激活动态量化为 float8 来加速微调,类似于[用 float8 进行预训练](pretraining.html)。

量化感知训练 (QAT)

量化感知训练的目标是在训练或微调期间使模型适应量化数值,以减轻模型最终实际量化时(通常在微调后的服务步骤中)不可避免的量化降级。TorchAO 的 QAT 支持已成功用于最近发布的 [Llama-3.2 量化 1B/3B](https://ai.meta.com/blog/meta-llama-quantized-lightweight-models/) 和 [LlamaGuard-3-8B](https://github.com/meta-llama/PurpleLlama/blob/main/Llama-Guard3/8B/MODEL_CARD.md) 模型,以提高量化模型的质量。

TorchAO 的 QAT 支持涉及两个独立的步骤:准备和转换。准备步骤在训练期间“伪”量化激活和/或权重,这意味着高精度值(例如 bf16)映射到其对应的量化值,而无需实际将其转换为目标低精度 dtype(例如 int4)。转换步骤在训练后应用,将模型中的“伪”量化操作替换为执行 dtype 转换的“真实”量化

_images/qat_diagram.png

使用 TorchAO 的 QAT 进行微调有多种选择

  1. 使用我们与 [TorchTune](https://github.com/pytorch/torchtune) 的集成

  2. 使用我们与 [Axolotl](https://github.com/axolotl-ai-cloud/axolotl) 的集成

  3. 直接使用我们的 QAT API 和您自己的训练循环

选项 1:TorchTune QAT 集成

TorchAO 的 QAT 支持已集成到 TorchTune 的分布式微调配方中。用户可以运行以下等效命令,而不是以下不使用 QAT 的完全分布式微调命令:

# Regular fine-tuning without QAT
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3_2/3B_full batch_size=16

用户可以改为运行以下等效命令。请注意,指定量化器是可选的。

# Fine-tuning with QAT, by default:
#   activations are fake quantized to asymmetric per token int8
#   weights are fake quantized to symmetric per group int4
#   configurable through "quantizer._component_" in the command
tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full batch_size=16

微调后,用户可以按如下方式量化和评估结果模型。无论微调过程中是否使用 QAT,这都相同

# Quantize model weights to int4
tune run quantize --config quantization \
    model._component_=torchtune.models.llama3_2.llama3_2_3b \
    checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
    'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32

# Evaluate the int4 model on hellaswag and wikitext
tune run eleuther_eval --config eleuther_evaluation \
    batch_size=1 \
    'tasks=[hellaswag, wikitext]' \
    model._component_=torchtune.models.llama3_2.llama3_2_3b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32

微调后应打印以下内容

|  Tasks  |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|---------|------:|------|------|--------|---|-----:|---|-----:|
|hellaswag|      1|none  |None  |acc     |↑  |0.5021|±  |0.0050|
|         |       |none  |None  |acc_norm|↑  |0.6797|±  |0.0047|

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6965|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.6206|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |13.2199|±  |   N/A|

您可以将这些值与使用和不使用 QAT 的情况进行比较,看看 QAT 在多大程度上帮助减轻了量化降级!例如,当在 [OpenAssistant Conversations (OASST1)](https://hugging-face.cn/datasets/OpenAssistant/oasst1) 数据集上微调 Llama-3.2-3B 时,我们发现量化模型使用 QAT 比不使用 QAT 实现了 3.4% 的更高准确率,恢复了 69.8% 的总体量化降级

_images/qat_eval.png

除了上述示例中的普通 QAT,TorchAO 的 QAT 还可以与 LoRA 组合使用,从而实现 [1.89 倍的训练加速](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700)并降低 36.1% 的内存使用。这在 TorchTune 的 [QAT + LoRA 微调配方](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py)中实现,可以使用以下命令运行

# Fine-tuning with QAT + LoRA
tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora batch_size=16

有关如何在 TorchTune 中设置 QAT 的更多详细信息,请参阅[本教程](https://docs.pytorch.ac.cn/torchtune/main/tutorials/qat_finetune.html)。

选项 2:Axolotl QAT 集成

Axolotl 最近也添加了一个利用 TorchAO 的 QAT 支持的 QAT 微调配方。要开始,请尝试使用以下命令通过 QAT 微调 Llama-3.2-3B

axolotl train examples/llama-3/3b-qat-fsdp2.yaml
# once training is complete, perform the quantization step

axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
# you should now have a quantized model saved in ./outputs/qat_out/quatized

请参阅 [Axolotl QAT 文档](https://docs.axolotl.ai/docs/qat.html)了解完整详细信息。

选项 3:TorchAO QAT API

如果您更喜欢使用不同的训练框架或自己的自定义训练循环,您可以直接调用 TorchAO 的 QAT API 在微调前转换模型。这些 API 是 TorchTune 和 Axolotl QAT 集成在后台调用的。

在此示例中,我们将在单个 GPU 上微调 Llama3 的迷你版本

import torch
from torchtune.models.llama3 import llama3

# Set up a smaller version of llama3 to fit in a single A100 GPU
# For smaller GPUs, adjust the model attributes accordingly
def get_model():
    return llama3(
        vocab_size=4096,
        num_layers=16,
        num_heads=16,
        num_kv_heads=4,
        embed_dim=2048,
        max_seq_len=2048,
    ).cuda()

# Example training loop
def train_loop(m: torch.nn.Module):
    optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
    loss_fn = torch.nn.CrossEntropyLoss()
    for i in range(10):
        example = torch.randint(0, 4096, (2, 16)).cuda()
        target = torch.randn((2, 16, 4096)).cuda()
        output = m(example)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

接下来,运行准备步骤,该步骤对模型进行伪量化。在此示例中,我们使用 int8 每令牌动态激活和 int4 对称每组权重作为我们的量化方案。请注意,尽管我们正在针对较低的整数精度,但训练仍以较高的浮点精度 (float32) 执行算术运算,因为我们并未实际转换伪量化值。

from torchao.quantization import (
    quantize_,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
model = get_model()

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)

# fine-tune
train_loop(model)

微调后,我们得到一个原始高精度模型。这个微调后的模型与原始模型具有完全相同的结构。唯一的区别是 QAT 微调后的模型的权重更适合量化,这在推理期间会很有利。下一步是实际量化模型

from torchao.quantization import (
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
    FromIntXQuantizationAwareTrainingConfig,
)

# convert: transform fake quantization ops into actual quantized ops
# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

现在我们的模型已准备好进行服务,并且通常会比在微调期间未应用准备步骤(伪量化)时具有更高的量化准确性。

有关使用 TorchAO QAT API 的完整详细信息,请参阅 [QAT 自述文件](https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md)。

替代旧版 API

上述 quantize_ API 是使用 TorchAO QAT 的推荐流程。我们还为特定量化方案提供替代的旧版“量化器”API,但与上述示例不同,这些 API 不可自定义。

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
model = qat_quantizer.prepare(model)

# train
train_loop(model)

# convert: transform fake quantization ops into actual quantized ops
# swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
model = qat_quantizer.convert(model)

量化低秩适应 (QLoRA)

(即将推出!)

Float8 量化微调

(即将推出!)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源