fuse_modules#
- class torch.ao.quantization.fuse_modules.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)[source]#
将模块列表融合为单个模块。
仅融合以下模块序列:conv, bn;conv, bn, relu;conv, relu;linear, relu;bn, relu。所有其他序列保持不变。对于这些序列,会将列表中的第一项替换为融合后的模块,并将列表中的其余模块替换为 identity(恒等映射)。
- 参数:
model – 包含待融合模块的模型
modules_to_fuse – 待融合的模块名称列表的列表。如果只有一个待融合的模块列表,也可以直接传入字符串列表。
inplace – 布尔值,指定是否在原模型上进行原地(inplace)融合;默认情况下返回一个新的模型。
fuser_func – 接收一个模块列表并输出相同长度的融合后模块列表的函数。例如,fuser_func([convModule, BNModule]) 返回列表 [ConvBNModule, nn.Identity()]。默认为 torch.ao.quantization.fuse_known_modules
fuse_custom_config_dict – 融合的自定义配置
# Example of fuse_custom_config_dict fuse_custom_config_dict = { # Additional fuser_method mapping "additional_fuser_method_mapping": { (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn }, }
- 返回:
融合后的模型。如果 inplace=True,则会创建一个新的副本。
示例
>>> m = M().eval() >>> # m is a module containing the sub-modules below >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) >>> output = fused_m(input) >>> m = M().eval() >>> # Alternately provide a single list of modules to fuse >>> modules_to_fuse = ['conv1', 'bn1', 'relu1'] >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) >>> output = fused_m(input)