MKLDNN 后端#
创建日期:2025 年 5 月 10 日 | 最后更新日期:2025 年 7 月 17 日
MKLDNN 是一个开源的跨平台性能库,用于深度学习应用程序的基本构建块。
# The flag below controls whether enable MKLDNN backend in Pytorch.
torch.backends.mkldnn.enabled = True
用户可以通过以下方式禁用 MKLDNN 后端
torch.backends.mkldnn.enabled = False
MKLDNN 后端的 Bfloat16 (BF16)#
从 PyTorch 2.9 开始,有一组 API 用于控制 float32 算子的内部计算精度。
# The flag below controls the internal computation precision for mkldnn matmul. Default ieee is float32.
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
# The flag below controls the internal computation precision for mkldnn conv. Default ieee is float32.
torch.backends.mkldnn.conv.fp32_precision = "ieee"
# The flag below controls the internal computation precision for mkldnn rnn. Default ieee is float32.
torch.backends.mkldnn.rnn.fp32_precision = "ieee"
请注意,除了矩阵乘法和卷积本身之外,内部使用矩阵乘法或卷积的函数和 nn 模块也会受到影响。这些包括 torch.nn.Linear、torch.nn._ConvNd、torch.cdist()、torch.tensordot()、torch.nn.functional.affine_grid() 和 torch.nn.functional.grid_sample()、torch.nn.AdaptiveLogSoftmaxWithLoss、torch.nn.GRU 和 torch.nn.LSTM。
为了了解精度和速度,请参阅下面的示例代码和基准测试数据(在 SPR 上)
torch.manual_seed(0)
a_full = torch.randn(10240, 10240, dtype=torch.double)
b_full = torch.randn(10240, 10240, dtype=torch.double)
ab_full = a_full @ b_full
mean = ab_full.abs().mean() # 80.7451
a = a_full.float()
b = b_full.float()
# Do matmul at BF16 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'bf16'
ab_bf16 = a @ b # expected speedup with BF16 dot-product acceleration
error = (ab_bf16 - ab_full).abs().max() # 1.3704
relative_error = error / mean # 0.0170
print(error, relative_error)
# Do matmul at TF32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'tf32'
ab_tf32 = a @ b # expected speedup with TF32 dot-product acceleration
error = (ab_tf32 - ab_full).abs().max() # 0.0004
relative_error = error / mean # 0.00000552
print(error, relative_error)
# Do matmul FP32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
ab_fp32 = a @ b
error = (ab_fp32 - ab_full).abs().max() # 0.0003
relative_error = error / mean # 0.00000317
print(error, relative_error)
从上面的示例可以看出,使用 BF16 时,SPR 上的速度大约快 7 倍,并且相对于双精度而言,相对误差大约大两个数量级。如果需要完整的 FP32 精度,用户可以通过以下方式禁用 BF16
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
torch.backends.mkldnn.conv.fp32_precision = 'ieee'
torch.backends.mkldnn.rnn.fp32_precision = 'ieee'
要在 C++ 中关闭 BF16 标志,您可以执行以下操作
at::globalContext().setFloat32Precision("ieee", "mkldnn", "matmul");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "conv");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "rnn");
如果 fp32_precision 设置为 ieee,我们可以覆盖特定算子或后端的通用设置。
torch.backends.fp32_precision = "bf16"
torch.backends.mkldnn.fp32_precision = "ieee"
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
在这种情况下,torch.backends.mkldnn.fp32_precision 和 torch.backends.mkldnn.matmul.fp32_precision 都会被覆盖为 bf16。