自动混合精度¶
Pytorch/XLA 的 AMP 通过支持 XLA:GPU 和 XLA:TPU 设备上的自动混合精度来扩展 Pytorch 的 AMP 包。AMP 用于通过以 float32 精度执行某些操作,而以较低精度数据类型(float16 或 bfloat16,取决于硬件支持)执行其他操作来加速训练和推理。本文档介绍了如何在 XLA 设备上使用 AMP 以及最佳实践。
XLA:TPU 的 AMP¶
TPU 上的 AMP 会自动将操作转换为以 float32 或 bfloat16 精度运行,因为 TPU 原生支持 bfloat16。下面是一个简单的 TPU AMP 示例
from torch_xla.amp import syncfree
import torch_xla.core.xla_model as xm
# Creates model and optimizer in default precision
model = Net().to('xla')
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass
with autocast(torch_xla.device()):
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward()
loss.backward()
xm.optimizer_step.(optimizer)
autocast(torch_xla.device()) 是 torch.autocast('xla') 的别名,当 XLA 设备是 TPU 时。或者,如果一个脚本只用于 TPU,那么可以直接使用 torch.autocast('xla', dtype=torch.bfloat16)。
如果存在应该被自动转换但未包含的操作,请提交一个 issue 或 pull request。
XLA:TPU 的 AMP 最佳实践¶
autocast应该只包装网络的前向传播和损失计算。反向传播操作将以与相应前向传播操作相同的精度运行。由于 TPU 使用 bfloat16 混合精度,因此不需要梯度缩放。
Pytorch/XLA 提供了修改后的 优化器 版本,避免了设备和主机之间的额外同步。
支持的操作¶
TPU 上的 AMP 的工作方式类似于 PyTorch 的 AMP。自动转换规则总结如下:
只有非原地操作和 Tensor 方法才有资格进行自动转换。原地变体和显式提供 out=… Tensor 的调用在启用了自动转换的区域中是允许的,但不会经过自动转换。例如,在启用了自动转换的区域中,a.addmm(b, c) 可以进行自动转换,但 a.addmm_(b, c) 和 a.addmm(b, c, out=d) 则不能。为了获得最佳性能和稳定性,请在启用了自动转换的区域中优先使用非原地操作。
以 float64 或非浮点数据类型运行的操作不符合资格,无论是否启用了自动转换,它们都将以这些类型运行。此外,使用显式 dtype=… 参数调用的操作不符合资格,并且会产生遵循 dtype 参数的输出。
未在下面列出的操作不会经过自动转换。它们将以其输入的类型运行。如果未列出的操作是自动转换操作的下游,自动转换仍可能更改它们运行的类型。
自动转换为 bfloat16 的操作
__matmul__, addbmm, addmm, addmv, addr, baddbmm,bmm, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, linear, matmul, mm, relu, prelu, max_pool2d
自动转换为 float32 的操作
batch_norm, log_softmax, binary_cross_entropy, binary_cross_entropy_with_logits, prod, cdist, trace, chloesky ,inverse, reflection_pad, replication_pad, mse_loss, cosine_embbeding_loss, nll_loss, multilabel_margin_loss, qr, svd, triangular_solve, linalg_svd, linalg_inv_ex
自动转换为最宽输入类型的操作
stack, cat, index_copy
XLA:GPU 的 AMP¶
XLA:GPU 设备上的 AMP 重用了 PyTorch 的 AMP 规则。有关 CUDA 的特定行为,请参阅 PyTorch 的 AMP 文档。下面是一个简单的 CUDA AMP 示例
from torch_xla.amp import syncfree
import torch_xla.core.xla_model as xm
# Creates model and optimizer in default precision
model = Net().to('xla')
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
scaler = GradScaler()
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass
with autocast(torch_xla.device()):
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward pass
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size())
scaler.step(optimizer)
scaler.update()
当 XLA 设备是 CUDA 设备(XLA:GPU)时,autocast(torch_xla.device()) 是 torch.cuda.amp.autocast() 的别名。或者,如果一个脚本只用于 CUDA 设备,那么可以直接使用 torch.cuda.amp.autocast,但这要求 torch 已编译为支持 cuda,并且支持 torch.bfloat16 数据类型。我们建议在 XLA:GPU 上使用 autocast(torch_xla.device()),因为它不需要 torch.cuda 支持任何数据类型,包括 torch.bfloat16。
XLA:GPU 的 AMP 最佳实践¶
autocast应该只包装网络的前向传播和损失计算。反向传播操作将以与相应前向传播操作相同的精度运行。在使用 Cuda 设备上的 AMP 时,请勿设置
XLA_USE_F16标志。这将覆盖 AMP 提供的每个运算符的精度设置,并导致所有运算符以 float16 精度执行。使用梯度缩放来防止 float16 梯度下溢。
Pytorch/XLA 提供了修改后的 优化器 版本,避免了设备和主机之间的额外同步。
示例¶
我们的 mnist 训练脚本 和 imagenet 训练脚本 演示了 AMP 如何在 TPU 和 GPU 上使用。