评价此页

MKLDNN 后端#

创建日期:2025 年 5 月 10 日 | 最后更新日期:2025 年 7 月 17 日

MKLDNN 是一个开源的跨平台性能库,包含深度学习应用程序的基本构建块。

# The flag below controls whether enable MKLDNN backend in Pytorch.
torch.backends.mkldnn.enabled = True

用户可以通过以下方式禁用 MKLDNN 后端:

torch.backends.mkldnn.enabled = False

MKLDNN 后端的 Bfloat16 (BF16)#

从 PyTorch 2.9 开始,提供了一组 API 来控制 float32 运算符的内部计算精度。

# The flag below controls the internal computation precision for mkldnn matmul. Default ieee is float32.
torch.backends.mkldnn.matmul.fp32_precision = "ieee"

# The flag below controls the internal computation precision for mkldnn conv. Default ieee is float32.
torch.backends.mkldnn.conv.fp32_precision = "ieee"

# The flag below controls the internal computation precision for mkldnn rnn. Default ieee is float32.
torch.backends.mkldnn.rnn.fp32_precision = "ieee"

请注意,除了 matmuls 和 convolutions 本身之外,内部使用 matmuls 或 convolutions 的函数和 nn 模块也会受到影响。这些包括 torch.nn.Lineartorch.nn._ConvNdtorch.cdist()torch.tensordot()torch.nn.functional.affine_grid()torch.nn.functional.grid_sample()torch.nn.AdaptiveLogSoftmaxWithLosstorch.nn.GRUtorch.nn.LSTM

为了了解精度和速度,请参见下面的示例代码和基准测试数据(在 SPR 上)。

torch.manual_seed(0)
a_full = torch.randn(10240, 10240, dtype=torch.double)
b_full = torch.randn(10240, 10240, dtype=torch.double)
ab_full = a_full @ b_full
mean = ab_full.abs().mean()  # 80.7451

a = a_full.float()
b = b_full.float()

# Do matmul at BF16 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'bf16'
ab_bf16 = a @ b  # expected speedup with BF16 dot-product acceleration
error = (ab_bf16 - ab_full).abs().max()  # 1.3704
relative_error = error / mean  # 0.0170
print(error, relative_error)

# Do matmul at TF32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'tf32'
ab_tf32 = a @ b  # expected speedup with TF32 dot-product acceleration
error = (ab_tf32 - ab_full).abs().max()  # 0.0004
relative_error = error / mean  # 0.00000552
print(error, relative_error)

# Do matmul FP32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
ab_fp32 = a @ b
error = (ab_fp32 - ab_full).abs().max()  # 0.0003
relative_error = error / mean  # 0.00000317
print(error, relative_error)

从上面的示例可以看出,使用 BF16,在 SPR 上的速度提高了约 7 倍,与双精度相比,相对误差大约大 2 个数量级。如果需要完整的 FP32 精度,用户可以通过以下方式禁用 BF16:

torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
torch.backends.mkldnn.conv.fp32_precision = 'ieee'
torch.backends.mkldnn.rnn.fp32_precision = 'ieee'

要禁用 C++ 中的 BF16 标志,您可以执行以下操作:

at::globalContext().setFloat32Precision("ieee", "mkldnn", "matmul");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "conv");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "rnn");

如果 fp32_precision 设置为 ieee,我们可以为特定的运算符或后端覆盖通用设置。

torch.backends.fp32_precision = "bf16"
torch.backends.mkldnn.fp32_precision = "ieee"
torch.backends.mkldnn.matmul.fp32_precision = "ieee"

在这种情况下,torch.backends.mkldnn.fp32_precisiontorch.backends.mkldnn.matmul.fp32_precision 都将被覆盖为 bf16。