PyTorch 2.0 NNModule 支持#
创建于: 2023年4月06日 | 最后更新于: 2025年6月10日
作者: Will Constable
torch.compile
对 torch.nn.Module 对象有特殊处理,它会以不同于追踪任意 Python 类的方式来追踪它们,目的是通过做出关于结构的假设来生成更快的代码。
本文档描述了由于这种专业化而产生的一些权衡或边缘情况。
NNModule Hook 支持#
之前,torch.compile
不支持 nn.Modules 上的 hooks,如果注册了 hooks,它们在编译后的程序中将被忽略。事实上,许多用户根本不使用 nn.Module hooks,或者只在调试工作流中使用它们,但存在将 nn.Module hooks 与 torch.compile
结合使用的有效用例。
通过 nn.Module.call 实现编排的 Hooks 包括 _forward_pre_hooks
、forward_hooks
、_backward_pre_hooks
和 _backward_hooks
,并将被引用为“call hooks”。这些 hooks 在 torch.compile
中得到部分支持,但存在以下限制。
另一类 Hooks 包括 _state_dict_hooks
及其 pre
和 load_
变体,它们仍然不受 torch.compile
支持。
nn.Module.__call__
Hooks 用法和限制#
默认情况下,torch.compile
会追踪 nn.Module.__call__
的内容,这意味着它会遇到并运行前向/预前向 hooks。如果您在调用 torch.compile
之前注册了 hooks,并且之后不移除或更改 hooks,那么您的用例应该得到默认支持。
后向/预后向 hooks 通常也得到支持,但有类似的注意事项:目前在 dynamo 中访问 backward_hooks 字典时会发生图中断 (graph-breaks),这可能通过一些工作来避免。图中断也会影响后向 hooks 的触发时机,因为图段被作为 autograd-functions 运行,它们会同时产生所有梯度。假设 dynamo 可以避免因存在后向 hooks 而导致图中断,我们仍然期望一系列模块的后向 hooks 在整个编译图的后向运行后一起触发。
“允许模块”上的 hooks torch.compile
特别处理常见的模块,如 torch.conv,以及难以追踪的模块,允许它们在 dynamo 图中被不透明地调用,而不是被 dynamo 追踪。对于这类模块,hooks 当前会触发图中断,以便受影响的模块在 dynamo 外部运行。根据模型,这可能会导致显著的性能下降,并且需要额外的工作来改进此支持。
skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards
设置为 True,这意味着不会在每个 nn.Module hook 字典上安装 guards,从而通过减少 guard 执行时间来提高运行时性能,但代价是无法在编译后发现任何 hook 字典被更改。
如果您希望在编译后能够移除或修改 hooks,并且让 torch.compile
做出适当的反应(通过重新编译),那么您需要将 skip_nnmodule_hook_guards=False
,并预期由于增加了 guards 而产生的运行时开销。
TODO:确认后向/预后向 hooks 是否工作,并相应地记录。
state_dict Hooks#
State dict hooks 尚未在 torch.compile
中得到支持。
TODO:如果 hook 触发图中断,则发出一次警告。如果存在 hook,则发出一次警告以指向本文档。