torch.set_float32_matmul_precision#
- torch.set_float32_matmul_precision(precision)[source]#
设置 float32 矩阵乘法的内部精度。
以较低精度运行 float32 矩阵乘法可以显著提高性能,并且在某些程序中精度损失的影响可以忽略不计。
支持三种设置:
“highest”,float32 矩阵乘法在内部计算中使用 float32 数据类型(24 个尾数位,23 位显式存储)。
“high”,float32 矩阵乘法使用 TensorFloat32 数据类型(10 个尾数位显式存储),或者将每个 float32 数字视为两个 bfloat16 数字的和(约 16 个尾数位,14 位显式存储),如果存在合适的快速矩阵乘法算法。否则,float32 矩阵乘法将按“highest”精度计算。有关 bfloat16 方法的更多信息,请参见下文。
“medium”,float32 矩阵乘法在内部计算中使用 bfloat16 数据类型(8 个尾数位,7 位显式存储),如果存在使用该数据类型进行内部计算的快速矩阵乘法算法。否则,float32 矩阵乘法将按“high”精度计算。
使用“high”精度时,float32 乘法可能使用一种基于 bfloat16 的算法,该算法比简单地截断到某个较小的尾数位数(例如,TensorFloat32 为 10 位,bfloat16 为 7 位显式存储)更为复杂。有关此算法的完整描述,请参阅 [Henry2019]。在此简要说明,第一步是认识到我们可以将一个 float32 数字完美地编码为三个 bfloat16 数字的和(因为 float32 有 23 个尾数位,而 bfloat16 有 7 个显式存储位,并且两者具有相同的指数位数)。这意味着两个 float32 数字的乘积可以精确地给出为九个 bfloat16 数字乘积的和。然后,我们可以通过丢弃其中一些乘积来在准确性和速度之间进行权衡。“high”精度算法特别只保留三个最显著的乘积,这恰好排除了涉及任一输入最后 8 个尾数位的乘积。这意味着我们可以将输入表示为两个 bfloat16 数字的和,而不是三个。因为 bfloat16 融合乘加 (FMA) 指令通常比 float32 指令快 10 倍以上,所以使用 bfloat16 精度进行三次乘法和 2 次加法比使用 float32 精度进行一次乘法要快。
注意
这不会改变 float32 矩阵乘法的输出 dtype,它控制着矩阵乘法内部的计算方式。
注意
这不会改变卷积操作的精度。其他标志,如 torch.backends.cudnn.allow_tf32,可能会控制卷积操作的精度。
注意
此标志目前仅影响一种本机设备类型:CUDA。如果设置为“high”或“medium”,则在计算 float32 矩阵乘法时将使用 TensorFloat32 数据类型,这等同于设置 torch.backends.cuda.matmul.allow_tf32 = True。当设置为“highest”(默认值)时,float32 数据类型将用于内部计算,这等同于设置 torch.backends.cuda.matmul.allow_tf32 = False。
- 参数
precision (str) – 可以设置为“highest”(默认值)、“high”或“medium”(见上文)。