评价此页

PyTorch 2.0 NNModule 支持#

创建于: 2023年4月06日 | 最后更新于: 2025年6月10日

作者: Will Constable

torch.compile 对 torch.nn.Module 对象有特殊处理,它会以不同于追踪任意 Python 类的方式来追踪它们,目的是通过做出关于结构的假设来生成更快的代码。

本文档描述了由于这种专业化而产生的一些权衡或边缘情况。

NNModule Hook 支持#

之前,torch.compile 不支持 nn.Modules 上的 hooks,如果注册了 hooks,它们在编译后的程序中将被忽略。事实上,许多用户根本不使用 nn.Module hooks,或者只在调试工作流中使用它们,但存在将 nn.Module hooks 与 torch.compile 结合使用的有效用例。

通过 nn.Module.call 实现编排的 Hooks 包括 _forward_pre_hooksforward_hooks_backward_pre_hooks_backward_hooks,并将被引用为“call hooks”。这些 hooks 在 torch.compile 中得到部分支持,但存在以下限制。

另一类 Hooks 包括 _state_dict_hooks 及其 preload_ 变体,它们仍然不受 torch.compile 支持。

nn.Module.__call__ Hooks 用法和限制#

默认情况下,torch.compile 会追踪 nn.Module.__call__ 的内容,这意味着它会遇到并运行前向/预前向 hooks。如果您在调用 torch.compile 之前注册了 hooks,并且之后不移除或更改 hooks,那么您的用例应该得到默认支持。

后向/预后向 hooks 通常也得到支持,但有类似的注意事项:目前在 dynamo 中访问 backward_hooks 字典时会发生图中断 (graph-breaks),这可能通过一些工作来避免。图中断也会影响后向 hooks 的触发时机,因为图段被作为 autograd-functions 运行,它们会同时产生所有梯度。假设 dynamo 可以避免因存在后向 hooks 而导致图中断,我们仍然期望一系列模块的后向 hooks 在整个编译图的后向运行后一起触发。

“允许模块”上的 hooks torch.compile 特别处理常见的模块,如 torch.conv,以及难以追踪的模块,允许它们在 dynamo 图中被不透明地调用,而不是被 dynamo 追踪。对于这类模块,hooks 当前会触发图中断,以便受影响的模块在 dynamo 外部运行。根据模型,这可能会导致显著的性能下降,并且需要额外的工作来改进此支持。

skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,这意味着不会在每个 nn.Module hook 字典上安装 guards,从而通过减少 guard 执行时间来提高运行时性能,但代价是无法在编译后发现任何 hook 字典被更改。

如果您希望在编译后能够移除或修改 hooks,并且让 torch.compile 做出适当的反应(通过重新编译),那么您需要将 skip_nnmodule_hook_guards=False,并预期由于增加了 guards 而产生的运行时开销。

TODO:确认后向/预后向 hooks 是否工作,并相应地记录。

state_dict Hooks#

State dict hooks 尚未在 torch.compile 中得到支持。

TODO:如果 hook 触发图中断,则发出一次警告。如果存在 hook,则发出一次警告以指向本文档。