评价此页

torch.set_float32_matmul_precision#

torch.set_float32_matmul_precision(precision)[source]#

设置 float32 矩阵乘法的内部精度。

以较低的精度运行 float32 矩阵乘法可以显著提高性能,并且在某些程序中精度损失的影响可以忽略不计。

支持三种设置

  • “highest”(最高),float32 矩阵乘法在内部计算中使用 float32 数据类型(24 位尾数,23 位显式存储)。

  • “high”(高),float32 矩阵乘法使用 TensorFloat32 数据类型(10 位尾数,显式存储)或将每个 float32 数视为两个 bfloat16 数的和(约 16 位尾数,14 位显式存储),前提是存在相应的快速矩阵乘法算法。否则,float32 矩阵乘法将按“highest”精度计算。有关 bfloat16 方法的更多信息,请参阅下文。

  • “medium”(中等),float32 矩阵乘法使用 bfloat16 数据类型(8 位尾数,7 位显式存储)进行内部计算,前提是存在使用该数据类型进行内部计算的快速矩阵乘法算法。否则,float32 矩阵乘法将按“high”精度计算。

使用“high”精度时,float32 乘法可能使用一种基于 bfloat16 的算法,该算法比简单地截断到较小的尾数位数(例如,TensorFloat32 为 10 位,bfloat16 显式存储为 7 位)更复杂。有关该算法的完整描述,请参阅 [Henry2019]。在此简要解释,第一步是认识到我们可以将一个 float32 数完美地编码为三个 bfloat16 数的和(因为 float32 有 23 位尾数,而 bfloat16 有 7 位显式存储,并且两者具有相同的指数位数)。这意味着两个 float32 数的乘积可以精确地表示为九个 bfloat16 数乘积的和。然后,我们可以通过丢弃其中一些乘积来权衡精度与速度。“high”精度算法特别只保留了三个最重要的乘积,这恰好排除了所有涉及任一输入最后 8 位尾数的乘积。这意味着我们可以将输入表示为两个 bfloat16 数的和,而不是三个。因为 bfloat16 熔合乘加 (FMA) 指令通常比 float32 指令快 10 倍以上,所以使用 bfloat16 精度进行三次乘法和 2 次加法比使用 float32 精度进行一次乘法要快。

Henry2019

http://arxiv.org/abs/1904.06376

注意

这不会改变 float32 矩阵乘法的输出 dtype,它控制矩阵乘法的内部计算是如何执行的。

注意

这不会改变卷积运算的精度。其他标志,如 torch.backends.cudnn.allow_tf32,可能会控制卷积运算的精度。

注意

此标志目前仅影响一种原生设备类型:CUDA。如果设置为“high”或“medium”,则在计算 float32 矩阵乘法时将使用 TensorFloat32 数据类型,这等同于将 torch.backends.cuda.matmul.allow_tf32 = True。当设置为“highest”(默认值)时,float32 数据类型用于内部计算,这等同于将 torch.backends.cuda.matmul.allow_tf32 = False

参数

precision (str) – 可以设置为“highest”(默认值)、“high”或“medium”(见上文)。