PyTorch 2.0 NNModule 支持#
创建于:2023年4月6日 | 最后更新于:2025年6月10日
torch.compile
对 torch.nn.Module 对象有特殊处理,它以不同于追踪任意 Python 类的方式追踪它们,目的是通过做出结构上的假设来生成更快的代码。
本文档描述了由于这种专门化而带来的一些权衡或边缘情况。
NNModule Hooks 支持#
以前,torch.compile
不支持 nn.Modules 上的 hooks,如果注册了 hooks,它们在编译后的程序中会被忽略。事实上,许多用户根本不使用 nn.Module hooks,或者只将其用于调试工作流,但将 nn.Module hooks 与 torch.compile
组合使用存在有效用例。
通过 nn.Module.__call__ 实现编排的 hooks 包括 _forward_pre_hooks
、forward_hooks
、_backward_pre_hooks
和 _backward_hooks
,它们将被称为“call hooks”。这些 hooks 由 torch.compile
部分支持,但存在一些限制。
另一类 hooks 包括 _state_dict_hooks
及其 pre
和 load_
变体,它们仍然不受 torch.compile
支持。
nn.Module.__call__
Hooks 使用和限制#
默认情况下,torch.compile
会追踪 nn.Module.__call__
的内容,这意味着它会遇到并运行 forward/pre-forward hooks。如果您在调用 torch.compile
之前安装了 hooks,并且之后不移除或更改 hooks,那么您的用例应该会得到默认支持。
Backward/Pre-backward hooks 通常也得到支持,但有类似的注意事项:目前在 dynamo 中访问 backward_hooks 字典会导致图中断,这可能通过一些工作来避免。图中断也会影响 backward hooks 的触发时机,因为图片段会作为 autograd-functions 运行,它们会同时产生所有 grad。假设 dynamo 能够避免在存在 backward-hooks 时发生图中断,我们仍然期望一系列模块的 backward hooks 在整个编译图的 backward 运行后一起触发。
“允许的模块”上的 hooks torch.compile
特别对待常见的模块,如 torch.conv,以及难以追踪的模块,方法是允许它们在 dynamo 图中被不透明地调用,而不是被 dynamo 追踪。对于这些模块,hooks 目前会触发图中断,以便受影响的模块在 dynamo 之外运行。根据模型的不同,这可能会导致显著的性能下降,并且需要额外的工作来改进此支持。
skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards
设置为 True,这意味着不会在每个 nn.Module hook 字典上安装 guards,从而减少 guard 执行时间以提高运行时性能,但代价是无法发现 hook 字典在编译后是否被更改。
如果您希望在编译后能够移除或修改 hooks,并让 torch.compile
做出适当的反应(通过重新编译),则需要将 skip_nnmodule_hook_guards=False
,并为增加的 guards 预期运行时开销。
TODO:确认 backward/pre_backward hooks 是否工作,并据此进行文档记录
state_dict Hooks#
torch.compile
尚未支持 state dict hooks。
TODO:当存在 hooks 时,发出一次警告。当存在 hooks 时,发出一次警告以指向此文档。