• 文档 >
  • 控制 MXU 浮点精度
快捷方式

控制 MXU 浮点精度

作者: 何耀祥

创建日期 2025/05/15

最后修改日期 2025/05/15

在本教程中,您将学习在使用 PyTorch/XLA 的某些加速器(如 TPU)时,如何控制矩阵乘法(mat mul)操作的浮点精度。您还将学习如何访问 torch 的浮点信息以及如何直观地检查数字的浮点表示。

介绍

Google TPU 内置了在物理模块(称为矩阵乘法单元或 MXU)上优化的矩阵乘法。为了保持速度,研究人员发现了一种低成本的权衡。研究表明,神经网络可以以低于 FP32 的精度进行训练,“而不会对模型精度产生任何明显影响”。但范围情况并非如此。由于范数等操作,保留 FP32 的范围很重要。解决方案是 bfloat16:它具有与 FP32 相同的范围,但精度较低。

Nvidia V100 及更新的 GPU 还包含称为 TensorCores 的专用矩阵乘法单元。这些 GPU 使用一种称为 TF32 的数值格式,它具有与 FP32 和 bfloat16 相同的范围,但精度居中(10 位尾数),因为 TF32 总共只有 19 位。

在 FP32 值上执行的矩阵乘法操作将为 TPU 生成 bfloat16 结果,为 Nvidia GPU 生成 TF32(19 位)结果。

bits layout

低精度硬件上的高精度计算

即使 bfloat16 只有 7 位尾数,也可以进行更高精度的计算。这是通过将数字分解为其组件来完成的。为了建立直观理解,请想象一个 MXU 在十进制(小数)数字系统中支持 2 位数字。目标是乘以具有 4 位精度的数字,例如 9.111 和 9.222。在无限精度下,乘积为 84.021642。请注意,具有 4 位精度的两个数字会生成结果中两倍多的数字精度。但鉴于数字格式是 4 位,结果将被舍入为 84.02。

最简单的方法是将数字舍入为 9.1 和 9.2,结果为 83.72。这在概念上是 PyTorch/XLA 在 TPU 上的“默认”精度设置。

下一个方法是将每个数字分解为两部分,高位和低位(H 和 L):$(9.1 + 0.011) \times (9.2 + 0.022)$。这等于 $(H \times H + H \times L + L \times H + L \times L)$。前三个矩阵乘法构成了三趟方法,并将有效精度大致翻倍。第四项 $L \times L$ 被忽略,查看结果 $0.000242$,很容易看出该值不会对最终结果做出贡献。某些 $L \times L$ 值可能会生成第四项,将值改变一位,但每隔一段时间添加一位信息,相对于运行另一个乘法的成本,价值不大。

                  +--------+--------+
                  | 9.222           |
                  +--------+--------+
                  | 9.2    | 0.022  |
+--------+--------+--------+--------+
|9.111   | 9.1    |83.72   | 0.2002 |
+--------+--------+--------+--------+
|        | 0.011  | 0.1012 |        | = 84.0214 => 84.02
+--------+--------+--------+--------+

再次扩展此方法将产生大约三倍的精度。这个想法是将数字分解为高位、中位和低位(H、M 和 L),生成九个可能的项:$(H + M + L) \times (H + M + L) = HH + HM + MH + MM + HL + LH + ML + LM + LL$。最后三项被忽略,前六项构成了六趟方法。它基本上等同于 FP32,并为次要位的方差留有余地。

PyTorch/XLA 和 TPU

PyTorch/XLA 允许在 torch_xla.backends.set_mat_mul_precision() 函数中控制一趟、三趟和六趟方法。有效值为 defaulthighhighest。现在,您将调查这三个设置之间的差异。

警告:虽然本笔记本演示了多次设置精度,但建议只在脚本开头设置一次精度。

准备工作

确保您在 TPU 上运行此教程。您可以使用 Google Colab 访问 TPU。

导入所需的包。

import torch
import torch_xla.backends

torch.set_printoptions(precision=20, sci_mode=False, linewidth=240)
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.

Epsilon 是 1.0 和下一个可表示数之间的最小差值。从 torch 中检索该值。

eps = torch.finfo(torch.bfloat16).eps
print(f"bfloat16 epsilon: {eps}")
print(f"return type of torch.finfo: {type(eps)}")
bfloat16 epsilon: 0.0078125
return type of torch.finfo: <class 'float'>

Epsilon 也定义为 1 / 2^p,其中 p 是尾数的位数。

print(1 / (2**7))
0.0078125

在此之间的数字可能会向上舍入为 1.0 + epsilon,或向下舍入为 1.0。

print(
    torch.tensor(
        [1.0, 1.0 + eps / 4.0, 1.0 + eps / 2, 1.0 + eps * 3 / 4, 1.0 + eps],
        dtype=torch.bfloat16,
    ))
tensor([1.00000000000000000000, 1.00000000000000000000, 1.00000000000000000000, 1.00781250000000000000, 1.00781250000000000000], dtype=torch.bfloat16)

准备直接查看位

设置将二进制字符串转换为 FP32 数字的工具,反之亦然。创建一个函数来生成随机矩阵。

通常,在测试 MXU(或 TensorCore)时,传递矩阵以促使 XLA 使用 MXU,而不是更慢但更精确的 FP32 单位。

import struct


def binary_fraction_to_fp32(bstr: str) -> float:
  if bstr[:4] != "0b1.":
    raise ValueError(f"Invalid binary string: {bstr}")
  fraction_bits = bstr[4:]
  mantissa = 1.0
  for i, bit in enumerate(fraction_bits):
    mantissa += int(bit) * 2**-(i + 1)
  return float(mantissa)


def fp32_to_binary_fraction(fp32_float: float) -> str:
  x_bytes = struct.pack(">f", fp32_float)  # Big-endian IEEE 754 float32
  as_int = struct.unpack(">I", x_bytes)[0]  # Interpret bits as uint32
  sign = (as_int >> 31) & 0b1
  exponent = (as_int >> 23) & 0xFF
  mantissa = as_int & 0x7FFFFF  # lower 23 bits
  return f"FORMAT:0b SIGN:{sign} EXPONENT:{exponent:08b} MANTISSA:{mantissa:023b} VALUE={fp32_float}"


def get_rand_matrix():
  """Returns a diagonal matrix of shape 1024, 1024, values between 0.999 and 1.111"""
  eye = torch.eye(1024, dtype=torch.float32, device="xla")
  rand_ = torch.rand(
      (1024, 1024), dtype=torch.float32, device="xla") * 0.2 + 0.9
  result = eye * rand_
  assert torch.nonzero(result).size(0) == 1024, torch.nonzero(result).size(0)
  return result

检查数字

生成一个表示 $1 + \text{bf16\_eps}/2$ 的 FP32 数字。这将使 bfloat16 尾数超出第八位精度。

one_plus_half_eps = binary_fraction_to_fp32("0b1." + "0" * 7 + "1" + "0" * 15)
print(f"FP32     : {one_plus_half_eps }")
print(f"1 + eps/2: {1.0 + eps / 2}")
FP32     : 1.00390625
1 + eps/2: 1.00390625

打印 FP32 和 BF16 的位数。请注意,第 8 位丢失了。这再次证实了 BF16 无法表示第 8 位精度。

print(f"FP32: {fp32_to_binary_fraction(one_plus_half_eps)}")
ones_bf16 = torch.tensor(
    one_plus_half_eps, dtype=torch.bfloat16).to(torch.float32).item()
print(f"BF16: {fp32_to_binary_fraction(ones_bf16)}")
FP32: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625
BF16: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000000000000000000000 VALUE=1.0

MXU

将感兴趣的数字放在对角矩阵中。通过将它们放在矩阵中,XLA 将在 MXU 上执行计算。通过使矩阵对角化,计算将等同于逐元素乘法。

请注意,这些值在相乘之前基本上被舍入为 1.0,导致输出为 1.0。这是 TPU 中发生的精度损失。

X = get_rand_matrix()
Y = get_rand_matrix()
X[0, 0] = one_plus_half_eps
Y[0, 0] = one_plus_half_eps
Z = torch.matmul(X, Y)
print(f"X: {fp32_to_binary_fraction(X[0][0].item())}")
print(f"Y: {fp32_to_binary_fraction(Y[0][0].item())}")
print(f"Z: {fp32_to_binary_fraction(Z[0][0].item())}")
X: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625
Y: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625
Z: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000000000000000000000 VALUE=1.0

bfloat16 硬件上的 FP32 精度

三趟和六趟方法生成了更多的精度位数。开启最高精度模式(六趟)并再次运行实验。请注意,TPU 已计算出 FP32 精度。

Z_ref = torch.matmul(
    X.to("cpu").to(torch.float32),
    Y.to("cpu").to(torch.float32))
print(f"Z_ref: {fp32_to_binary_fraction(Z_ref[0][0].item())}")
torch_xla.backends.set_mat_mul_precision("highest")
Z = torch.matmul(X, Y)
print(f"Z:     {fp32_to_binary_fraction(Z[0][0].item())}")
WARNING:torch_xla.backends:Setting mat mul precision multiple times is not recommended. If you need to do so, please empirically verify that the precision setting is behaving as expected.
Z_ref: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000010000000010000000 VALUE=1.0078277587890625
Z:     FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000010000000010000000 VALUE=1.0078277587890625

边缘情况数字

在前面的示例中,您没有看到六趟和 FP32 乘法之间的区别。现在,您将使用一个边缘情况数字来演示六趟方法和完整 FP32 之间的最终位差异。

X = get_rand_matrix()
Y = get_rand_matrix()
X[0, 0] = 1.2
Y[0, 0] = 1.2
Z_ref = torch.matmul(
    X.to("cpu").to(torch.float32),
    Y.to("cpu").to(torch.float32))
print(f"Z_ref: {fp32_to_binary_fraction(Z_ref[0][0].item())}")
torch_xla.backends.set_mat_mul_precision("highest")
Z = torch.matmul(X, Y)
print(f"Z:     {fp32_to_binary_fraction(Z[0][0].item())}")
WARNING:torch_xla.backends:Setting mat mul precision multiple times is not recommended. If you need to do so, please empirically verify that the precision setting is behaving as expected.
Z_ref: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:01110000101000111101100 VALUE=1.440000057220459
Z:     FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:01110000101000111101101 VALUE=1.4400001764297485

结论

在本教程中,您学习了如何控制矩阵乘法(mat mul)操作的浮点精度。您还学习了用于通过三趟和六趟方法生成更高精度的内部算法。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源