评价此页

PyTorch 2.0 NNModule 支持#

创建于:2023年4月6日 | 最后更新于:2025年6月10日

作者Will Constable

torch.compile 对 torch.nn.Module 对象有特殊处理,它以不同于追踪任意 Python 类的方式追踪它们,目的是通过做出结构上的假设来生成更快的代码。

本文档描述了由于这种专门化而带来的一些权衡或边缘情况。

NNModule Hooks 支持#

以前,torch.compile 不支持 nn.Modules 上的 hooks,如果注册了 hooks,它们在编译后的程序中会被忽略。事实上,许多用户根本不使用 nn.Module hooks,或者只将其用于调试工作流,但将 nn.Module hooks 与 torch.compile 组合使用存在有效用例。

通过 nn.Module.__call__ 实现编排的 hooks 包括 _forward_pre_hooksforward_hooks_backward_pre_hooks_backward_hooks,它们将被称为“call hooks”。这些 hooks 由 torch.compile 部分支持,但存在一些限制。

另一类 hooks 包括 _state_dict_hooks 及其 preload_ 变体,它们仍然不受 torch.compile 支持。

nn.Module.__call__ Hooks 使用和限制#

默认情况下,torch.compile 会追踪 nn.Module.__call__ 的内容,这意味着它会遇到并运行 forward/pre-forward hooks。如果您在调用 torch.compile 之前安装了 hooks,并且之后不移除或更改 hooks,那么您的用例应该会得到默认支持。

Backward/Pre-backward hooks 通常也得到支持,但有类似的注意事项:目前在 dynamo 中访问 backward_hooks 字典会导致图中断,这可能通过一些工作来避免。图中断也会影响 backward hooks 的触发时机,因为图片段会作为 autograd-functions 运行,它们会同时产生所有 grad。假设 dynamo 能够避免在存在 backward-hooks 时发生图中断,我们仍然期望一系列模块的 backward hooks 在整个编译图的 backward 运行后一起触发。

“允许的模块”上的 hooks torch.compile 特别对待常见的模块,如 torch.conv,以及难以追踪的模块,方法是允许它们在 dynamo 图中被不透明地调用,而不是被 dynamo 追踪。对于这些模块,hooks 目前会触发图中断,以便受影响的模块在 dynamo 之外运行。根据模型的不同,这可能会导致显著的性能下降,并且需要额外的工作来改进此支持。

skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,这意味着不会在每个 nn.Module hook 字典上安装 guards,从而减少 guard 执行时间以提高运行时性能,但代价是无法发现 hook 字典在编译后是否被更改。

如果您希望在编译后能够移除或修改 hooks,并让 torch.compile 做出适当的反应(通过重新编译),则需要将 skip_nnmodule_hook_guards=False,并为增加的 guards 预期运行时开销。

TODO:确认 backward/pre_backward hooks 是否工作,并据此进行文档记录

state_dict Hooks#

torch.compile 尚未支持 state dict hooks。

TODO:当存在 hooks 时,发出一次警告。当存在 hooks 时,发出一次警告以指向此文档。