• 文档 >
  • (第 1 部分)使用 float8 进行预训练
快捷方式

(第一部分)使用 float8 进行预训练

TorchAO 通过利用我们的量化和稀疏技术并将其集成到我们的合作伙伴框架中,提供端到端的预训练、微调和模型服务优化流程。这是展示此端到端流程的三个教程中的第一部分,重点关注预训练步骤。

_images/e2e_flow_part1.png

使用 torchao 进行 float8 预训练,在 512 个 GPU 集群上可提供高达 1.5 倍的速度提升,在使用最新 torchao.float8 行式配方的 2K H200 集群上可提供高达 1.34-1.43 倍的速度提升

在本教程中,我们将展示两种使用 torchao.float8 配方进行预训练的方法

  1. 使用 torchtitan 进行预训练,这是 PyTorch 官方预训练框架,与 torchao 原生集成。

  2. 直接使用 torchao 进行预训练,将 torchao 的 float8 训练配方集成到您自己的预训练代码中。

使用 torchtitan 进行预训练

在本教程中,我们将使用 torchtitan 和 torchao 的 float8 训练配方(行式缩放和张量式缩放)预训练 Llama3-8B。

Torchtitan 是 PyTorch 的官方预训练框架,与 torchao 原生集成,支持多种流行的旗舰模型,并具有常见的并行形式、float8 训练、分布式检查点等功能。有关更多详细信息,请参阅 torchtitan 文档

您可以使用此工作流程快速开始,体验“开箱即用”。用户通常会 fork torchtitan 并在其基础上进行开发。

先决条件

  1. (推荐)使用 conda 或 venv 创建一个新的虚拟环境。

  2. 安装 torchao.

  3. 安装 torchtitan,包括“下载分词器”步骤。

您现在已准备好使用以下配方之一启动预训练任务!

行式缩放

从 torchtitan 根目录运行以下命令,在 8 个 GPU 上启动 Llama3-8B 训练任务,并进行 float8 行式训练

NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --model.converters="float8" --float8.recipe_name="rowwise"

当使用超过 1 个 GPU 时,Torchtitan 将自动使用 FSDP2 进行训练并行化。要使用其他并行形式、修改超参数或更改其他训练配置,您可以直接编辑 llama3_8b.toml 文件或使用命令行标志(运行命令时带 --help 查看更多选项)。

您应该会看到类似以下的终端输出

[rank0]:[titan] 2025-06-04 08:51:48,074 - root - INFO - step:  1  loss: 12.2254  memory: 27.34GiB(28.78%)  tps: 375  tflops: 21.73  mfu: 2.20%
[rank0]:[titan] 2025-06-04 08:51:58,557 - root - INFO - step: 10  loss: 10.7069  memory: 30.99GiB(32.62%)  tps: 7,034  tflops: 407.35  mfu: 41.19%
[rank0]:[titan] 2025-06-04 08:52:10,224 - root - INFO - step: 20  loss:  8.9196  memory: 30.99GiB(32.62%)  tps: 7,022  tflops: 406.65  mfu: 41.12%
[rank0]:[titan] 2025-06-04 08:52:21,904 - root - INFO - step: 30  loss:  8.1423  memory: 30.99GiB(32.62%)  tps: 7,014  tflops: 406.23  mfu: 41.08%

如您所见,忽略预热步骤,我们实现了约 7k TPS 的吞吐量,峰值内存使用量为 30.99GB。要与 bfloat16 训练进行性能比较,您可以删除 --model.converters="float8" --float8.recipe_name="rowwise" 标志,并运行相同的命令以查看 bfloat16 训练的基准性能

NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile

您应该会看到以下输出

[rank0]:[titan] 2025-06-04 11:02:37,404 - root - INFO - step:  1  loss: 12.2611  memory: 27.22GiB(28.65%)  tps: 595  tflops: 34.47  mfu: 3.49%
[rank0]:[titan] 2025-06-04 11:02:49,027 - root - INFO - step: 10  loss: 10.4260  memory: 30.89GiB(32.51%)  tps: 6,344  tflops: 367.39  mfu: 37.15%
[rank0]:[titan] 2025-06-04 11:03:01,988 - root - INFO - step: 20  loss:  8.9482  memory: 30.89GiB(32.51%)  tps: 6,321  tflops: 366.06  mfu: 37.01%
[rank0]:[titan] 2025-06-04 11:03:14,991 - root - INFO - step: 30  loss:  8.1183  memory: 30.89GiB(32.51%)  tps: 6,300  tflops: 364.89  mfu: 36.89%
[rank0]:[titan] 2025-06-04 11:03:28,013 - root - INFO - step: 40  loss:  7.4659  memory: 30.89GiB(32.51%)  tps: 6,291  tflops: 364.36  mfu: 36.84%
[rank0]:[titan] 2025-06-04 11:03:39,769 - root - INFO - [GC] Peforming periodical GC collection. 0.02 seconds.

如您所见,bfloat16 基准实现了约 6.3k TPS 的吞吐量,峰值内存使用量为 30.89GB。

这意味着我们的 float8 行式缩放配方与 bfloat16 基准相比,实现了 1.11 倍的更高吞吐量,而峰值内存几乎相同!

请注意,使用张量式缩放配方可以实现更高的吞吐量改进,这在性能与精度曲线上的位置不同。

张量式缩放

使用张量式缩放的 Float8 训练是默认配方,因此我们可以省略 --float8.recipe_name 标志

NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --model.converters="float8"

您应该会看到类似以下的输出

[rank0]:[titan] 2025-06-04 10:52:19,648 - root - INFO - step:  1  loss: 12.2648  memory: 27.28GiB(28.71%)  tps: 557  tflops: 32.29  mfu: 3.26%
[rank0]:[titan] 2025-06-04 10:52:29,475 - root - INFO - step: 10  loss: 10.9106  memory: 30.91GiB(32.53%)  tps: 7,503  tflops: 434.53  mfu: 43.94%
[rank0]:[titan] 2025-06-04 10:52:40,166 - root - INFO - step: 20  loss:  9.0774  memory: 30.91GiB(32.53%)  tps: 7,663  tflops: 443.78  mfu: 44.87%
[rank0]:[titan] 2025-06-04 10:52:50,885 - root - INFO - step: 30  loss:  8.3233  memory: 30.91GiB(32.53%)  tps: 7,643  tflops: 442.66  mfu: 44.76%
[rank0]:[titan] 2025-06-04 10:53:01,613 - root - INFO - step: 40  loss:  7.6150  memory: 30.91GiB(32.53%)  tps: 7,637  tflops: 442.27  mfu: 44.72%

如您所见,我们实现了约 7.6k TPS 的吞吐量,峰值内存使用量为 30.91GB,与 bfloat16 基准相比,吞吐量高出 1.21 倍

选择配方

总结:行式缩放更适合优先考虑更准确的数值和训练稳定性的任务,而张量式缩放更适合优先考虑训练吞吐量的任务。

张量式缩放的更高吞吐量是以略高的量化误差(即,相对于 bfloat16 降低数值完整性)为代价的,相比于行式缩放。这是因为行式缩放使用更细粒度的缩放因子(每行而非每张量),这限制了可能导致缩放期间下溢的异常值的影响。

下面您可以看到 Llama3-8B 在 8xH100 GPU 上训练时,bfloat16、float8 张量式和 float8 行式训练的损失曲线比较

Loss curves for training Llama3-8B on 8xH100s with torchtitan using bfloat16, float8 tensorwise, and float8 rowwise training.

重要注意事项

  • 目前 torchtitan 中只支持 2 个或更多 GPU 的 float8 训练,不支持单 GPU 训练。

  • 您必须使用 --training.compile 才能实现高性能。torchao float8 训练配方原生构建于 torch.compile 之上,因此开箱即用!

直接使用 torchao 进行预训练

在本教程中,我们将直接使用 torchao API 预训练一个玩具模型。

您可以使用此工作流程将 torchao 直接集成到您自己的自定义预训练代码中。

先决条件

  1. (推荐)使用 conda 或 venv 创建一个新的虚拟环境。

  2. 安装 torchao.

您现在已准备好将 torchao 直接集成到您的训练代码中!

模型转换 API

用于将模型转换为使用 float8 训练的 torchao API 是:convert_to_float8_training。此 API 将递归地将模型中的 nn.Linear 模块转换为使用 Float8Linear

您可以使用 module_filter_fn 参数来确定哪些 nn.Linear 层应转换为使用 Float8Linear

您应该参考此性能基准表,以了解对于给定 GEMM 大小,您可以期望相对于 bfloat16 获得何种性能改进。

下面是一个展示如何使用的代码片段

import torch
from torch import nn
import torch.nn.functional as F

from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear
from torchao.float8 import convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
    nn.Linear(128, 1),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    output = m(x)
    # use fake labels for demonstration purposes
    fake_labels = torch.ones_like(output)
    loss = F.mse_loss(output, fake_labels)
    loss.backward()
    optimizer.step()

# save the model
torch.save({
    'model': m,
    'model_state_dict': m.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')

在预训练模型后,您可以选择对其进行微调,以适应更特定领域的数据集,并为最终服务时的量化做准备。在本教程的下一部分中,我们将探讨微调步骤中的几种模型优化选项。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源