数值精度#
创建日期:2021 年 10 月 13 日 | 最后更新日期:2024 年 9 月 24 日
在现代计算机中,浮点数使用 IEEE 754 标准表示。有关浮点算术和 IEEE 754 标准的更多详细信息,请参阅浮点算术。特别需要注意的是,浮点数提供有限的精度(单精度浮点数大约 7 位小数,双精度浮点数大约 16 位小数),并且浮点加法和乘法不具有结合律,因此操作顺序会影响结果。因此,PyTorch 不保证对数学上相同的浮点计算产生逐位相同的结果。同样,在不同的 PyTorch 版本、不同的提交或不同的平台之间,也不保证逐位相同的结果。特别地,即使对于逐位相同的输入,即使在控制了随机性来源之后,CPU 和 GPU 的结果也可能不同。
批处理计算或切片计算#
PyTorch 中的许多操作支持批处理计算,即对输入的批处理元素执行相同的操作。一个例子是 torch.mm()
和 torch.bmm()
。虽然可以将批处理计算实现为批处理元素的循环,并对单个批处理元素应用必要的数学操作,但出于效率原因,我们不这样做,通常对整个批处理执行计算。在这种情况下,我们调用的数学库以及 PyTorch 内部操作实现可能会产生与非批处理计算略有不同的结果。特别是,假设 A
和 B
是维度适合批处理矩阵乘法的 3D 张量。那么 (A@B)[0]
(批处理结果的第一个元素)不保证与 A[0]@B[0]
(输入批处理的第一个元素的矩阵乘积)逐位相同,尽管数学上它是相同的计算。
类似地,应用于张量切片的操作不保证产生与应用于整个张量的相同操作的结果切片相同的结果。例如,设 A
是一个二维张量。A.sum(-1)[0]
不保证与 A[:,0].sum()
逐位相等。
极值#
当输入包含的值很大,使得中间结果可能超出所用数据类型的范围时,即使最终结果可以在原始数据类型中表示,它也可能溢出。例如:
import torch
a=torch.tensor([1e20, 1e20]) # fp32 type by default
a.norm() # produces tensor(inf)
a.double().norm() # produces tensor(1.4142e+20, dtype=torch.float64), representable in fp32
线性代数 (torch.linalg
)#
非有限值#
torch.linalg
使用的外部库(后端)不保证在输入包含 inf
或 NaN
等非有限值时的行为。因此,PyTorch 也不作此保证。这些操作可能会返回包含非有限值的张量,或者引发异常,甚至导致段错误。
在调用这些函数之前,考虑使用 torch.isfinite()
来检测这种情况。
线性代数中的极值#
torch.linalg
中的函数比其他 PyTorch 函数有更多的极值。
求解器和逆假设输入矩阵 A
是可逆的。如果它接近不可逆(例如,如果它有一个非常小的奇异值),那么这些算法可能会默默地返回不正确的结果。这些矩阵被称为病态的。如果提供病态输入,这些函数的结果在使用不同设备或通过关键字 driver
使用不同后端时可能会有所不同。
像 svd
、eig
和 eigh
这样的谱操作,当它们的输入具有彼此接近的奇异值时,也可能返回不正确的结果(并且它们的梯度可能无限大)。这是因为用于计算这些分解的算法很难收敛于这些输入。
在 float64
中运行计算(NumPy 默认这样做)通常会有所帮助,但它并不能在所有情况下解决这些问题。通过 torch.linalg.svdvals()
分析输入的谱或通过 torch.linalg.cond()
分析它们的条件数可能有助于检测这些问题。
Nvidia Ampere(及更高版本)设备上的 TensorFloat-32 (TF32)#
在 Ampere(及更高版本)Nvidia GPU 上,PyTorch 可以使用 TensorFloat32 (TF32) 来加速计算密集型操作,特别是矩阵乘法和卷积。当使用 TF32 张量核心执行操作时,只读取输入尾数的前 10 位。这可能会降低精度并产生令人惊讶的结果(例如,将矩阵乘以单位矩阵可能会产生与输入不同的结果)。默认情况下,TF32 张量核心对于矩阵乘法是禁用的,对于卷积是启用的,尽管大多数神经网络工作负载在使用 TF32 时与使用 fp32 具有相同的收敛行为。如果您的网络不需要完全的 float32 精度,我们建议通过 torch.backends.cuda.matmul.allow_tf32 = True
为矩阵乘法启用 TF32 张量核心。如果您的网络对矩阵乘法和卷积都需要完全的 float32 精度,那么 TF32 张量核心也可以通过 torch.backends.cudnn.allow_tf32 = False
禁用。
欲了解更多信息,请参阅TensorFloat32。
FP16 和 BF16 GEMM 的降低精度归约#
半精度 GEMM 操作通常以单精度进行中间累加(归约),以提高数值精度并增强抗溢出能力。为了性能,某些 GPU 架构,特别是较新的架构,允许将中间累加结果截断为降低的精度(例如,半精度)。这种改变从模型收敛的角度来看通常是良性的,但可能导致意外结果(例如,最终结果应该可以在半精度中表示,却出现 inf
值)。如果降低精度归约存在问题,可以通过 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
将其关闭。
BF16 GEMM 操作也有类似的标志,默认情况下是开启的。如果 BF16 降低精度归约存在问题,可以通过 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
将其关闭。
欲了解更多信息,请参阅allow_fp16_reduced_precision_reduction 和 allow_bf16_reduced_precision_reduction。
缩放点积注意力 (SDPA) 中 FP16 和 BF16 的降低精度归约#
当使用 FP16/BF16 输入时,一个朴素的 SDPA 数学后端会因为使用低精度中间缓冲区而累积显著的数值误差。为了缓解这个问题,现在的默认行为是将 FP16/BF16 输入向上转换为 FP32。计算在 FP32/TF32 中执行,然后将最终的 FP32 结果向下转换回 FP16/BF16。这将提高数学后端使用 FP16/BF16 输入时最终输出的数值精度,但会增加内存使用,并可能导致数学后端因计算从 FP16/BF16 BMM 转移到 FP32/TF32 BMM/Matmul 而导致性能下降。
对于偏爱降低精度归约以提高速度的场景,可以通过以下设置启用它们:torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
AMD Instinct MI200 设备上的降低精度 FP16 和 BF16 GEMM 和卷积#
在 AMD Instinct MI200 GPU 上,FP16 和 BF16 V_DOT2 和 MFMA 矩阵指令将输入和输出的非规范值刷新为零。FP32 和 FP64 MFMA 矩阵指令不会将输入和输出的非规范值刷新为零。受影响的指令仅由 rocBLAS (GEMM) 和 MIOpen (卷积) 内核使用;所有其他 PyTorch 操作都不会遇到此行为。所有其他支持的 AMD GPU 也不会遇到此行为。
rocBLAS 和 MIOpen 为受影响的 FP16 操作提供了替代实现。BF16 操作没有提供替代实现;BF16 数字具有比 FP16 数字更大的动态范围,因此不太可能遇到非规范值。对于 FP16 替代实现,FP16 输入值被转换为中间 BF16 值,然后在累积 FP32 操作后转换回 FP16 输出。通过这种方式,输入和输出类型保持不变。
当使用 FP16 精度进行训练时,某些模型可能因 FP16 非规范值刷新为零而无法收敛。非规范值在训练的反向传播过程中,在梯度计算时更频繁地出现。PyTorch 默认会在反向传播过程中使用 rocBLAS 和 MIOpen 的替代实现。可以使用环境变量 ROCBLAS_INTERNAL_FP16_ALT_IMPL 和 MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL 覆盖默认行为。这些环境变量的行为如下:
前向 |
反向 |
|
---|---|---|
未设置环境变量 |
原始 |
替代 |
环境变量设置为 1 |
替代 |
替代 |
环境变量设置为 0 |
原始 |
原始 |
以下是可能使用 rocBLAS 的操作列表:
torch.addbmm
torch.addmm
torch.baddbmm
torch.bmm
torch.mm
torch.nn.GRUCell
torch.nn.LSTMCell
torch.nn.Linear
torch.sparse.addmm
以下 torch._C._ConvBackend 实现
slowNd
slowNd_transposed
slowNd_dilated
slowNd_dilated_transposed
以下是可能使用 MIOpen 的操作列表:
torch.nn.Conv[Transpose]Nd
以下 torch._C._ConvBackend 实现
ConvBackend::Miopen
ConvBackend::MiopenDepthwise
ConvBackend::MiopenTranspose