评价此页

Patching Batch Norm#

创建于:2023年01月03日 | 最后更新于:2025年06月11日

发生了什么?#

Batch Norm 需要对 running_mean 和 running_var 进行原地更新,其大小与输入相同。Functorch 不支持对接受批处理张量的常规张量进行原地更新(即不允许 regular.add_(batched))。因此,当对单个模块的输入批次进行 vmap 时,我们会遇到此错误。

如何解决#

最受支持的方法之一是将 BatchNorm 切换为 GroupNorm。选项 1 和 2 支持这一点。

所有这些选项都假设您不需要 running stats。如果您正在使用一个模块,这意味着假设您不会在评估模式下使用 batch norm。如果您有在评估模式下使用 running batch norm 和 vmap 的用例,请提交一个 issue。

选项 1:更改 BatchNorm#

如果您想更改为 GroupNorm,请将所有 BatchNorm 替换为:

BatchNorm2d(C, G, track_running_stats=False)

这里的 C 与原始 BatchNorm 中的 C 相同。G 是将 C 分割成的组数。因此,C % G == 0,作为回退,您可以将 C == G,这意味着每个通道将单独处理。

如果您必须使用 BatchNorm 并且您自己构建了模块,您可以更改模块以不使用 running stats。换句话说,在任何有 BatchNorm 模块的地方,将 track_running_stats 标志设置为 False。

BatchNorm2d(64, track_running_stats=False)

选项 2:torchvision 参数#

一些 torchvision 模型,如 resnet 和 regnet,可以接受 norm_layer 参数。这些参数通常默认为 BatchNorm2d,如果它们被默认设置的话。

相反,您可以将其设置为 GroupNorm。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))

这里,再次强调,c % g == 0,所以作为回退,请将 g = c

如果您一定要使用 BatchNorm,请确保使用不使用 running stats 的版本。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

选项 3:functorch 的 patching#

functorch 添加了一些功能,允许快速、原地修改模块以不使用 running stats。更改 norm 层更易出错,因此我们未提供此选项。如果您有一个网络,并且希望 BatchNorm 不使用 running stats,您可以运行 replace_all_batch_norm_modules_ 以原地修改模块,使其不使用 running stats。

from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

选项 4:eval 模式#

在 eval 模式下运行时,running_mean 和 running_var 不会更新。因此,vmap 可以支持此模式。

model.eval()
vmap(model)(x)
model.train()