Note

Go to the end to download the full example code.
Building a Convolution/Batch Norm Fuser with torch.compile#

How to register custom fusion patterns with torch.compile's pattern matcher
PyTorch v2.7.0
Note

This optimization only works for models in inference mode (i.e. model.eval()). However, torch.compile's pattern matching system works for both training and inference.
First, let's get some imports out of the way (we will be using all of these later in the code).
from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
For this tutorial, we are going to create a model consisting of convolutions and batch norms. Note that this model has some tricky components: some of the conv/batch norm patterns are hidden within Sequentials, and one of the BatchNorms is wrapped in another Module.
class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = M().to(device)
model.eval()
Fusing Convolution with Batch Norm#
One of the primary challenges with trying to automatically fuse convolution and batch norm in PyTorch is that PyTorch does not provide an easy way of accessing the computational graph. torch.compile resolves this problem by capturing the computational graph during compilation, allowing us to apply pattern-based optimizations across the entire model, including operations nested within Sequential modules or wrapped in custom modules.
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import register_replacement
torch.compile will capture a graph representation of our model. During compilation, modules hidden within Sequential containers and wrapped modules are all inlined into the graph, making them available for pattern matching and optimization.
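To see that inlining for ourselves, one option (a small debugging sketch, not needed for the fusion pass) is to compile the model with a backend that just prints the captured FX graph and then runs it unchanged:

import torch._dynamo
import torch.fx

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # The captured graph: conv1/bn1, the layers inside the Sequential, and the
    # wrapped BatchNorm all show up as ordinary nodes in one flat graph.
    gm.print_readable()
    return gm.forward  # run the captured graph unchanged

torch.compile(model, backend=inspect_backend)(torch.randn(2, 1, 4, 4, device=device))
torch._dynamo.reset()  # start fresh so the later compilations below are unaffected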
Fusing Convolution with Batch Norm#
Unlike some other fusions, fusing convolution with batch norm does not require any new operators. Instead, since batch norm during inference consists of a pointwise add and multiply, these operations can be "baked" into the preceding convolution's weights. This allows us to remove the batch norm entirely from our model! Read https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The code here is copied from pytorch/pytorch for clarity purposes.
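For reference, this is the algebra that the code below implements, per output channel. In eval mode the batch norm applies an affine map built from its running statistics $\mu, \sigma^2$, affine parameters $\gamma, \beta$, and $\varepsilon$, which folds into the preceding convolution's weight $W$ and bias $b$:

$$
\mathrm{bn}(z) = \gamma \, \frac{z - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta
\quad\Longrightarrow\quad
\mathrm{bn}\big(\mathrm{conv}(x;\, W,\, b)\big) = \mathrm{conv}\Big(x;\; \tfrac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\, W,\; \tfrac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\,(b - \mu) + \beta\Big)
$$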
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
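As a quick sanity check (a sketch, not part of the fusion pass itself), a freshly constructed conv/bn pair in eval mode should give the same result before and after fusing:

# Sanity check: fusing a standalone conv/bn pair should match running them separately.
_conv = nn.Conv2d(3, 8, 3).to(device).eval()
_bn = nn.BatchNorm2d(8).to(device).eval()
_fused = fuse_conv_bn_eval(_conv, _bn)
_x = torch.randn(2, 3, 16, 16, device=device)
with torch.no_grad():
    torch.testing.assert_close(_fused(_x), _bn(_conv(_x)), rtol=1e-4, atol=1e-4)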
Pattern Matching with torch.compile#
Now that we have our fusion logic, we need to register a pattern so that torch.compile's pattern matcher can recognize and replace it during compilation.
# Define the pattern we want to match: conv2d followed by batch_norm
def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias)
    bn_out = torch.nn.functional.batch_norm(
        conv_out, bn_mean, bn_var, bn_weight, bn_bias,
        training=False, eps=1e-5
    )
    return bn_out

def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    fused_weight, fused_bias = fuse_conv_bn_weights(
        conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias
    )
    return torch.nn.functional.conv2d(x, fused_weight, fused_bias)
# Example inputs are needed to trace the pattern functions.
# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement.
# These are used to trace the pattern functions to create the match template.
# IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here
# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence
# will be matched regardless of channels, kernel size, or spatial dimensions.
# - x: input tensor (batch_size, channels, height, width)
# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w)
# - conv_bias: (out_channels,)
# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels
example_inputs = [
    torch.randn(1, 1, 4, 4).to(device),  # x: input tensor
    torch.randn(1, 1, 1, 1).to(device),  # conv_weight: 1 output channel, 1 input channel, 1x1 kernel
    torch.randn(1).to(device),           # conv_bias: 1 output channel
    torch.randn(1).to(device),           # bn_mean: batch norm running mean
    torch.randn(1).to(device),           # bn_var: batch norm running variance
    torch.randn(1).to(device),           # bn_weight: batch norm weight (gamma)
    torch.randn(1).to(device),           # bn_bias: batch norm bias (beta)
]
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._inductor import config
# Create a pattern matcher pass and register our pattern
patterns = PatternMatcherPass()
register_replacement(
    conv_bn_pattern,
    conv_bn_replacement,
    example_inputs,
    pm.fwd_only,
    patterns,
)
# Create a custom pass function that applies our patterns
def conv_bn_fusion_pass(graph):
    return patterns.apply(graph)
# Set our custom pass in the config
config.post_grad_custom_post_pass = conv_bn_fusion_pass
Note

We're making some simplifications here for demonstration purposes, such as only matching 2D convolutions. The pattern matcher in torch.compile can handle more complex patterns.
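For example, supporting 1D convolutions would just mean registering a second pattern/replacement pair against the same PatternMatcherPass. A minimal sketch (the conv1d_* names below are illustrative, not part of the tutorial):

def conv1d_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    conv_out = torch.nn.functional.conv1d(x, conv_weight, conv_bias)
    return torch.nn.functional.batch_norm(
        conv_out, bn_mean, bn_var, bn_weight, bn_bias,
        training=False, eps=1e-5
    )

def conv1d_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    # fuse_conv_bn_weights is rank-agnostic, so the same folding works for conv1d weights
    fused_weight, fused_bias = fuse_conv_bn_weights(
        conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias
    )
    return torch.nn.functional.conv1d(x, fused_weight, fused_bias)

conv1d_example_inputs = [
    torch.randn(1, 1, 8).to(device),  # x: (batch, channels, length)
    torch.randn(1, 1, 3).to(device),  # conv_weight: (out_channels, in_channels, kernel_size)
    torch.randn(1).to(device),        # conv_bias
    torch.randn(1).to(device),        # bn_mean
    torch.randn(1).to(device),        # bn_var
    torch.randn(1).to(device),        # bn_weight
    torch.randn(1).to(device),        # bn_bias
]

register_replacement(
    conv1d_bn_pattern,
    conv1d_bn_replacement,
    conv1d_example_inputs,
    pm.fwd_only,
    patterns,
)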
Testing out our Fusion Pass#
We can now run this fusion pass on our initial toy model and verify that our results are identical. In addition, we can print out the code of our fused model and verify that there are no more batch norms.
from torch._dynamo.utils import counters
# Clear the counters before compilation
counters.clear()
# Ensure pattern matcher is enabled
config.pattern_matcher = True
fused_model = torch.compile(model, backend="inductor")
inp = torch.randn(5, 1, 1, 1).to(device)
# Run the model to trigger compilation and pattern matching
with torch.no_grad():
    output = fused_model(inp)
    expected = model(inp)

torch.testing.assert_close(output, expected)
# Check how many patterns were matched
assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched"
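# The text above also mentions printing the fused model's code to confirm the batch
# norms are gone. One way to do that (a sketch using torch._logging, not required for
# the test) is to enable the "output_code" artifact *before* compiling:
#
#   torch._logging.set_logs(output_code=True)
#
# The code Inductor generates for fused_model should then contain no batch_norm
# computation, only the fused convolutions.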
# Create a model with different shapes than our example_inputs
test_model_diff_shape = nn.Sequential(
    nn.Conv2d(3, 16, 5),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, 7),
    nn.BatchNorm2d(32),
).to(device).eval()
counters.clear()
compiled_diff_shape = torch.compile(test_model_diff_shape, backend="inductor")
test_input_diff_shape = torch.randn(1, 3, 28, 28).to(device)
with torch.no_grad():
    compiled_diff_shape(test_input_diff_shape)
# Check how many patterns were matched
assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched"
Benchmarking our Fusion on ResNet18#
We can test our fusion pass on a larger model like ResNet18 and see how much this pass improves inference performance.
import torchvision.models as models
import time
rn18 = models.resnet18().to(device)
rn18.eval()
inp = torch.randn(10, 3, 224, 224).to(device)
output = rn18(inp)
def benchmark(model, iters=20):
    with torch.no_grad():
        # Warm-up iterations so that one-time costs (e.g. compilation) are not timed
        for _ in range(10):
            model(inp)
        begin = time.time()
        for _ in range(iters):
            model(inp)
        return str(time.time() - begin)
# Benchmark original model
print("Original model time: ", benchmark(rn18))
# Compile with our custom pattern
compiled_with_pattern_matching = torch.compile(rn18, backend="inductor")
# Benchmark compiled model
print("\ntorch.compile (with conv-bn pattern matching and other fusions): ", benchmark(compiled_with_pattern_matching))
Conclusion#

As we can see, torch.compile provides a powerful way to implement graph transformations and optimizations through pattern matching. By registering custom patterns, we can extend torch.compile's optimization capabilities to handle domain-specific transformations.

The conv-bn fusion demonstrated here is just one example of what's possible with torch.compile's pattern matching system.