自动混合精度包 - torch.amp#
创建日期:2025 年 6 月 12 日 | 最后更新日期:2025 年 6 月 12 日
torch.amp
提供了混合精度的便捷方法,其中一些操作使用 torch.float32
(float
)数据类型,而其他操作使用较低精度的浮点数据类型(lower_precision_fp
):torch.float16
(half
)或 torch.bfloat16
。一些操作,如线性层和卷积,在 lower_precision_fp
中速度更快。其他操作,如归约,通常需要 float32
的动态范围。混合精度试图将每个操作与其适当的数据类型相匹配。
通常,使用 torch.float16
数据类型的“自动混合精度训练”会一起使用 torch.autocast
和 torch.amp.GradScaler
,如 自动混合精度示例 和 自动混合精度食谱 中所示。然而,torch.autocast
和 torch.GradScaler
是模块化的,如果需要,可以单独使用。如 torch.autocast
的 CPU 示例部分所示,“CPU 上的自动混合精度训练/推理”使用 torch.bfloat16
数据类型,只使用 torch.autocast
。
警告
torch.cuda.amp.autocast(args...)
和 torch.cpu.amp.autocast(args...)
已弃用。请改用 torch.amp.autocast("cuda", args...)
或 torch.amp.autocast("cpu", args...)
。torch.cuda.amp.GradScaler(args...)
和 torch.cpu.amp.GradScaler(args...)
已弃用。请改用 torch.amp.GradScaler("cuda", args...)
或 torch.amp.GradScaler("cpu", args...)
。
torch.autocast
和 torch.cpu.amp.autocast
在版本 1.10
中是新加入的。
自动类型转换#
- torch.amp.autocast_mode.is_autocast_available(device_type)[source]#
返回一个布尔值,指示
device_type
上是否可用自动类型转换。- 参数
device_type (str) – 要使用的设备类型。可能的值包括:“cuda”、“cpu”、“mtia”、“maia”、“xpu”等。该类型与
torch.device
的 type 属性相同。因此,您可以使用 Tensor.device.type 获取张量的设备类型。- 返回类型
- class torch.autocast(device_type, dtype=None, enabled=True, cache_enabled=None)[source]#
autocast
的实例用作上下文管理器或装饰器,允许您的脚本的某些区域以混合精度运行。在这些区域中,操作以自动类型转换选择的特定于操作的数据类型运行,以提高性能同时保持准确性。有关详细信息,请参阅 Autocast 操作参考。
进入启用自动类型转换的区域时,张量可以是任何类型。在使用自动类型转换时,不应在模型或输入上调用
half()
或bfloat16()
。autocast
应仅包装网络的正向传播(包括损失计算)。不建议在自动类型转换下进行反向传播。反向操作以与自动类型转换用于相应正向操作相同的类型运行。CUDA 设备示例
# Creates model and optimizer in default precision model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass (model + loss) with torch.autocast(device_type="cuda"): output = model(input) loss = loss_fn(output, target) # Exits the context manager before backward() loss.backward() optimizer.step()
有关更复杂的场景(例如,梯度惩罚、多个模型/损失、自定义 autograd 函数)中的用法(以及梯度缩放),请参阅 自动混合精度示例。
autocast
也可以用作装饰器,例如,在模型的forward
方法上class AutocastModel(nn.Module): ... @torch.autocast(device_type="cuda") def forward(self, input): ...
在启用自动类型转换的区域中生成的浮点张量可能是
float16
。返回到禁用自动类型转换的区域后,将它们与不同数据类型的浮点张量一起使用可能会导致类型不匹配错误。如果是这种情况,请将从自动类型转换区域生成的张量转换回float32
(或所需的其他数据类型)。如果来自自动类型转换区域的张量已经是float32
,则转换操作无效,不会产生额外开销。CUDA 示例# Creates some tensors in default dtype (here assumed to be float32) a_float32 = torch.rand((8, 8), device="cuda") b_float32 = torch.rand((8, 8), device="cuda") c_float32 = torch.rand((8, 8), device="cuda") d_float32 = torch.rand((8, 8), device="cuda") with torch.autocast(device_type="cuda"): # torch.mm is on autocast's list of ops that should run in float16. # Inputs are float32, but the op runs in float16 and produces float16 output. # No manual casts are required. e_float16 = torch.mm(a_float32, b_float32) # Also handles mixed input types f_float16 = torch.mm(d_float32, e_float16) # After exiting autocast, calls f_float16.float() to use with d_float32 g_float32 = torch.mm(d_float32, f_float16.float())
CPU 训练示例
# Creates model and optimizer in default precision model = Net() optimizer = optim.SGD(model.parameters(), ...) for epoch in epochs: for input, target in data: optimizer.zero_grad() # Runs the forward pass with autocasting. with torch.autocast(device_type="cpu", dtype=torch.bfloat16): output = model(input) loss = loss_fn(output, target) loss.backward() optimizer.step()
CPU 推理示例
# Creates model in default precision model = Net().eval() with torch.autocast(device_type="cpu", dtype=torch.bfloat16): for input in data: # Runs the forward pass with autocasting. output = model(input)
带 Jit Trace 的 CPU 推理示例
class TestModel(nn.Module): def __init__(self, input_size, num_classes): super().__init__() self.fc1 = nn.Linear(input_size, num_classes) def forward(self, x): return self.fc1(x) input_size = 2 num_classes = 2 model = TestModel(input_size, num_classes).eval() # For now, we suggest to disable the Jit Autocast Pass, # As the issue: https://github.com/pytorch/pytorch/issues/75956 torch._C._jit_set_autocast_mode(False) with torch.cpu.amp.autocast(cache_enabled=False): model = torch.jit.trace(model, torch.randn(1, input_size)) model = torch.jit.freeze(model) # Models Run for _ in range(3): model(torch.randn(1, input_size))
在启用自动类型转换的区域中发生的类型不匹配错误是一个 bug;如果您遇到这种情况,请提交问题。
autocast(enabled=False)
子区域可以嵌套在启用自动类型转换的区域中。局部禁用自动类型转换可能很有用,例如,如果您想强制一个子区域以特定dtype
运行。禁用自动类型转换可让您显式控制执行类型。在子区域中,应在内部使用前将来自周围区域的输入转换为dtype
。# Creates some tensors in default dtype (here assumed to be float32) a_float32 = torch.rand((8, 8), device="cuda") b_float32 = torch.rand((8, 8), device="cuda") c_float32 = torch.rand((8, 8), device="cuda") d_float32 = torch.rand((8, 8), device="cuda") with torch.autocast(device_type="cuda"): e_float16 = torch.mm(a_float32, b_float32) with torch.autocast(device_type="cuda", enabled=False): # Calls e_float16.float() to ensure float32 execution # (necessary because e_float16 was created in an autocasted region) f_float32 = torch.mm(c_float32, e_float16.float()) # No manual casts are required when re-entering the autocast-enabled region. # torch.mm again runs in float16 and produces float16 output, regardless of input types. g_float16 = torch.mm(d_float32, f_float32)
自动类型转换状态是线程本地的。如果您希望在新线程中启用它,则必须在该线程中调用上下文管理器或装饰器。这会影响
torch.nn.DataParallel
和torch.nn.parallel.DistributedDataParallel
在用于多个 GPU(每个进程)时(请参阅 使用多个 GPU)。- 参数
device_type (str, required) – 要使用的设备类型。可能的值包括:“cuda”、“cpu”、“mtia”、“maia”、“xpu”和“hpu”。该类型与
torch.device
的 type 属性相同。因此,您可以使用 Tensor.device.type 获取张量的设备类型。enabled (bool, optional) – 是否应在区域中启用自动类型转换。默认值:
True
dtype (torch_dtype, optional) – 在自动类型转换中运行的操作的数据类型。如果
dtype
为None
,它将使用get_autocast_dtype()
提供的默认值(CUDA 为torch.float16
,CPU 为torch.bfloat16
)。默认值:None
cache_enabled (bool, optional) – 是否应启用自动类型转换内部的权重缓存。默认值:
True
- torch.amp.custom_fwd(fwd=None, *, device_type, cast_inputs=None)[source]#
创建用于自定义 autograd 函数的
forward
方法的辅助装饰器。Autograd 函数是
torch.autograd.Function
的子类。有关更多详细信息,请参阅 示例页面。- 参数
device_type (str) – 要使用的设备类型。“cuda”、“cpu”、“mtia”、“maia”、“xpu”等。该类型与
torch.device
的 type 属性相同。因此,您可以使用 Tensor.device.type 获取张量的设备类型。cast_inputs (
torch.dtype
或 None, optional, default=None) – 如果不为None
,当forward
在启用自动类型转换的区域中运行时,它会将传入的浮点张量转换为目标数据类型(非浮点张量不受影响),然后禁用自动类型转换执行forward
。如果为None
,则forward
的内部操作将与当前的自动类型转换状态一起执行。
注意
如果装饰的
forward
在自动类型转换禁用区域之外调用,custom_fwd
将无效,cast_inputs
也将不起作用。
- torch.amp.custom_bwd(bwd=None, *, device_type)[source]#
创建用于自定义 autograd 函数的 backward 方法的辅助装饰器。
Autograd 函数是
torch.autograd.Function
的子类。确保backward
以与forward
相同的自动类型转换状态执行。有关更多详细信息,请参阅 示例页面。- 参数
device_type (str) – 要使用的设备类型。“cuda”、“cpu”、“mtia”、“maia”、“xpu”等。该类型与
torch.device
的 type 属性相同。因此,您可以使用 Tensor.device.type 获取张量的设备类型。
- class torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True)[source]#
参见
torch.autocast
。torch.cuda.amp.autocast(args...)
已弃用。请改用torch.amp.autocast("cuda", args...)
。
- torch.cuda.amp.custom_fwd(fwd=None, *, cast_inputs=None)[source]#
torch.cuda.amp.custom_fwd(args...)
已弃用。请改用torch.amp.custom_fwd(args..., device_type='cuda')
。
- torch.cuda.amp.custom_bwd(bwd)[source]#
torch.cuda.amp.custom_bwd(args...)
已弃用。请改用torch.amp.custom_bwd(args..., device_type='cuda')
。
- class torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True)[source]#
参见
torch.autocast
。torch.cpu.amp.autocast(args...)
已弃用。请改用torch.amp.autocast("cpu", args...)
。
梯度缩放#
如果特定操作的正向传播具有 float16
输入,则该操作的反向传播将产生 float16
梯度。幅度小的梯度值可能无法在 float16
中表示。这些值将变为零(“下溢”),因此相应参数的更新将丢失。
为防止下溢,“梯度缩放”将网络的损失乘以一个比例因子,并对缩放后的损失调用反向传播。换句话说,梯度值具有更大的幅度,因此它们不会变为零。
在优化器更新参数之前,应取消每个参数的梯度(.grad
属性)的缩放,以免比例因子干扰学习率。
注意
AMP/fp16 可能并非适用于所有模型!例如,大多数 bf16 预训练的模型无法在最大值为 65504 的 fp16 数值范围内运行,并且会导致梯度溢出而非下溢。在这种情况下,比例因子可能会减小到 1 以下,以尝试将梯度恢复到 fp16 动态范围内可表示的数字。虽然您可能期望比例因子始终大于 1,但我们的 GradScaler 不做此保证以维持性能。如果您在使用 AMP/fp16 时遇到损失或梯度中的 NaN,请验证您的模型是否兼容。
Autocast 操作参考#
操作资格#
以 float64
或非浮点数据类型运行的操作不符合资格,并且无论是否启用自动类型转换,都将以这些类型运行。
只有非原地操作和 Tensor 方法才符合资格。在启用自动类型转换的区域中允许原地变体和显式提供 out=...
Tensor 的调用,但它们不会通过自动类型转换。例如,在启用自动类型转换的区域中,a.addmm(b, c)
可以自动类型转换,但 a.addmm_(b, c)
和 a.addmm(b, c, out=d)
不能。为了获得最佳性能和稳定性,请在启用自动类型转换的区域中优先使用非原地操作。
使用显式 dtype=...
参数调用的操作不符合资格,并且将产生尊重 dtype
参数的输出。
CUDA 操作特定行为#
以下列表描述了在启用自动类型转换的区域中符合资格的操作的行为。无论这些操作是作为 torch.nn.Module
、函数还是 torch.Tensor
方法调用,它们都会通过自动类型转换。如果函数在多个命名空间中公开,则无论命名空间如何,它们都将通过自动类型转换。
下面未列出的操作不会通过自动类型转换。它们以其输入定义的数据类型运行。但是,如果未列出的操作位于自动类型转换操作的下游,自动类型转换仍可能更改它们运行的数据类型。
如果一个操作未列出,我们假设它在 float16
中是数值稳定的。如果您认为未列出的操作在 float16
中不具有数值稳定性,请提交一个问题。
可自动转换为 float16
的 CUDA 操作#
__matmul__
, addbmm
, addmm
, addmv
, addr
, baddbmm
, bmm
, chain_matmul
, multi_dot
, conv1d
, conv2d
, conv3d
, conv_transpose1d
, conv_transpose2d
, conv_transpose3d
, GRUCell
, linear
, LSTMCell
, matmul
, mm
, mv
, prelu
, RNNCell
可自动转换为 float32
的 CUDA 操作#
__pow__
, __rdiv__
, __rpow__
, __rtruediv__
, acos
, asin
, binary_cross_entropy_with_logits
, cosh
, cosine_embedding_loss
, cdist
, cosine_similarity
, cross_entropy
, cumprod
, cumsum
, dist
, erfinv
, exp
, expm1
, group_norm
, hinge_embedding_loss
, kl_div
, l1_loss
, layer_norm
, log
, log_softmax
, log10
, log1p
, log2
, margin_ranking_loss
, mse_loss
, multilabel_margin_loss
, multi_margin_loss
, nll_loss
, norm
, normalize
, pdist
, poisson_nll_loss
, pow
, prod
, reciprocal
, rsqrt
, sinh
, smooth_l1_loss
, soft_margin_loss
, softmax
, softmin
, softplus
, sum
, renorm
, tan
, triplet_margin_loss
提升至最宽输入类型的 CUDA 操作#
这些操作不需要特定数据类型即可保持稳定性,但它们接受多个输入,并要求输入的 Dtype 匹配。如果所有输入都是 float16
,则操作以 float16
运行。如果任何输入是 float32
,则自动类型转换会将所有输入转换为 float32
并以 float32
运行操作。
addcdiv
, addcmul
, atan2
, bilinear
, cross
, dot
, grid_sample
, index_put
, scatter_add
, tensordot
此处未列出的某些操作(例如,add
等二元操作)会在自动类型转换干预之前原生提升输入。如果输入是 float16
和 float32
的混合,这些操作将以 float32
运行并产生 float32
输出,无论是否启用自动类型转换。
优先使用 binary_cross_entropy_with_logits
而非 binary_cross_entropy
#
torch.nn.functional.binary_cross_entropy()
(以及包装它的 torch.nn.BCELoss
)的反向传播会产生无法在 float16
中表示的梯度。在启用自动类型转换的区域中,正向输入可能是 float16
,这意味着反向梯度必须在 float16
中表示(将 float16
正向输入自动类型转换为 float32
无济于事,因为这种转换必须在反向传播中进行逆转)。因此,binary_cross_entropy
和 BCELoss
在启用自动类型转换的区域中会引发错误。
许多模型在二元交叉熵层之前使用 sigmoid 层。在这种情况下,结合这两个层使用 torch.nn.functional.binary_cross_entropy_with_logits()
或 torch.nn.BCEWithLogitsLoss
。binary_cross_entropy_with_logits
和 BCEWithLogits
可以安全地进行自动类型转换。
XPU 操作特定行为(实验性)#
以下列表描述了在启用自动类型转换的区域中符合资格的操作的行为。无论这些操作是作为 torch.nn.Module
、函数还是 torch.Tensor
方法调用,它们都会通过自动类型转换。如果函数在多个命名空间中公开,则无论命名空间如何,它们都将通过自动类型转换。
下面未列出的操作不会通过自动类型转换。它们以其输入定义的数据类型运行。但是,如果未列出的操作位于自动类型转换操作的下游,自动类型转换仍可能更改它们运行的数据类型。
如果一个操作未列出,我们假设它在 float16
中是数值稳定的。如果您认为未列出的操作在 float16
中不具有数值稳定性,请提交一个问题。
可自动转换为 float16
的 XPU 操作#
addbmm
, addmm
, addmv
, addr
, baddbmm
, bmm
, chain_matmul
, multi_dot
, conv1d
, conv2d
, conv3d
, conv_transpose1d
, conv_transpose2d
, conv_transpose3d
, GRUCell
, linear
, LSTMCell
, matmul
, mm
, mv
, RNNCell
可自动转换为 float32
的 XPU 操作#
__pow__
, __rdiv__
, __rpow__
, __rtruediv__
, binary_cross_entropy_with_logits
, cosine_embedding_loss
, cosine_similarity
, cumsum
, dist
, exp
, group_norm
, hinge_embedding_loss
, kl_div
, l1_loss
, layer_norm
, log
, log_softmax
, margin_ranking_loss
, nll_loss
, normalize
, poisson_nll_loss
, pow
, reciprocal
, rsqrt
, soft_margin_loss
, softmax
, softmin
, sum
, triplet_margin_loss
提升至最宽输入类型的 XPU 操作#
这些操作不需要特定数据类型即可保持稳定性,但它们接受多个输入,并要求输入的 Dtype 匹配。如果所有输入都是 float16
,则操作以 float16
运行。如果任何输入是 float32
,则自动类型转换会将所有输入转换为 float32
并以 float32
运行操作。
bilinear
, cross
, grid_sample
, index_put
, scatter_add
, tensordot
此处未列出的某些操作(例如,add
等二元操作)会在自动类型转换干预之前原生提升输入。如果输入是 float16
和 float32
的混合,这些操作将以 float32
运行并产生 float32
输出,无论是否启用自动类型转换。
CPU 操作特定行为#
以下列表描述了在启用自动类型转换的区域中符合资格的操作的行为。无论这些操作是作为 torch.nn.Module
、函数还是 torch.Tensor
方法调用,它们都会通过自动类型转换。如果函数在多个命名空间中公开,则无论命名空间如何,它们都将通过自动类型转换。
下面未列出的操作不会通过自动类型转换。它们以其输入定义的数据类型运行。但是,如果未列出的操作位于自动类型转换操作的下游,自动类型转换仍可能更改它们运行的数据类型。
如果操作未列出,我们假设它在 bfloat16
中是数值稳定的。如果您认为未列出的操作在 bfloat16
中不具有数值稳定性,请提交一个问题。float16
与 bfloat16
的列表相同。
可自动转换为 bfloat16
的 CPU 操作#
conv1d
, conv2d
, conv3d
, bmm
, mm
, linalg_vecdot
, baddbmm
, addmm
, addbmm
, linear
, matmul
, _convolution
, conv_tbc
, mkldnn_rnn_layer
, conv_transpose1d
, conv_transpose2d
, conv_transpose3d
, prelu
, scaled_dot_product_attention
, _native_multi_head_attention
可自动转换为 float32
的 CPU 操作#
avg_pool3d
, binary_cross_entropy
, grid_sampler
, grid_sampler_2d
, _grid_sampler_2d_cpu_fallback
, grid_sampler_3d
, polar
, prod
, quantile
, nanquantile
, stft
, cdist
, trace
, view_as_complex
, cholesky
, cholesky_inverse
, cholesky_solve
, inverse
, lu_solve
, orgqr
, inverse
, ormqr
, pinverse
, max_pool3d
, max_unpool2d
, max_unpool3d
, adaptive_avg_pool3d
, reflection_pad1d
, reflection_pad2d
, replication_pad1d
, replication_pad2d
, replication_pad3d
, mse_loss
, cosine_embedding_loss
, nll_loss
, nll_loss2d
, hinge_embedding_loss
, poisson_nll_loss
, cross_entropy_loss
, l1_loss
, huber_loss
, margin_ranking_loss
, soft_margin_loss
, triplet_margin_loss
, multi_margin_loss
, ctc_loss
, kl_div
, multilabel_margin_loss
, binary_cross_entropy_with_logits
, fft_fft
, fft_ifft
, fft_fft2
, fft_ifft2
, fft_fftn
, fft_ifftn
, fft_rfft
, fft_irfft
, fft_rfft2
, fft_irfft2
, fft_rfftn
, fft_irfftn
, fft_hfft
, fft_ihfft
, linalg_cond
, linalg_matrix_rank
, linalg_solve
, linalg_cholesky
, linalg_svdvals
, linalg_eigvals
, linalg_eigvalsh
, linalg_inv
, linalg_householder_product
, linalg_tensorinv
, linalg_tensorsolve
, fake_quantize_per_tensor_affine
, geqrf
, _lu_with_info
, qr
, svd
, triangular_solve
, fractional_max_pool2d
, fractional_max_pool3d
, adaptive_max_pool3d
, multilabel_margin_loss_forward
, linalg_qr
, linalg_cholesky_ex
, linalg_svd
, linalg_eig
, linalg_eigh
, linalg_lstsq
, linalg_inv_ex
提升至最宽输入类型的 CPU 操作#
这些操作不需要特定数据类型即可保持稳定性,但它们接受多个输入,并要求输入的 Dtype 匹配。如果所有输入都是 bfloat16
,则操作以 bfloat16
运行。如果任何输入是 float32
,则自动类型转换会将所有输入转换为 float32
并以 float32
运行操作。
cat
, stack
, index_copy
此处未列出的某些操作(例如,add
等二元操作)会在自动类型转换干预之前原生提升输入。如果输入是 bfloat16
和 float32
的混合,这些操作将以 float32
运行并产生 float32
输出,无论是否启用自动类型转换。