torch.compile
具有不同的 autograd 语义#
创建日期:2025 年 6 月 26 日 | 最后更新日期:2025 年 6 月 26 日
当您将 torch.compile
应用到模型前向传播的某个函数时,它会自动为编译后的函数生成一个反向传播。在编译期间,它会为反向传播跟踪出一个图,该图将在每次调用 autograd 时使用。我们将 torch.compile
内部负责此任务的组件称为 AOTDispatcher
(有时也称为 AOTAutograd
)。
因此,torch.compile
会在前向传播的函数编译过程中,将计算的细节“烘焙”到跟踪出的反向传播图中。然而,在 eager 模式的 PyTorch 中,反向传播是动态的:在前向传播之外,您可以将 tensor.backward()
或 torch.autograd.grad(...)
的调用包装在可能改变其行为的上下文管理器中。
此页面记录了 torch.compile
的 autograd 语义与 eager 模式 PyTorch 的不同之处,以及如何规避这些差异。
Autocast
行为#
torch.compile
会预先假设反向传播是否会在环境 autocast 上下文管理器下运行。默认情况下,使用 torch._functorch.config.backward_pass_autocast
来控制该假设;不正确的假设可能导致静默的错误。
选项包括:
"same_as_forward"
(默认)。我们假设torch.compile
编译区域的反向传播将在该区域运行的相同 autocast 上下文管理器下运行(如果存在)。如果您的代码如下所示,请使用此选项:with torch.amp.autocast(...): y = torch.compile(region)(x) ... # backward pass run under the same autocast context as the compiled region z.backward()
"off"
。我们假设torch.compile
编译区域的反向传播不会在任何 autocast 上下文管理器下运行。如果您的代码如下所示,请使用此选项:with torch.amp.autocast(...): y = torch.compile(region)(x) ... # Backward pass runs under no autocast. z.backward()
还有第三个选项。如果将
torch._functorch.config.backward_pass_autocast
设置为 kwargs 列表,我们将假定反向传播在由这些 kwargs 构建的 autocast 上下文下运行。例如,如果您的代码如下所示:
y = torch.compile(region)(x) ... # Backward pass runs under special context manager with torch.amp.autocast(**kwargs): z.backward()
则将
torch._functorch.config.backward_pass_autocast = kwargs
设置为。
使用 patch
将选项应用于特定的 torch.compile
调用。
with torch.amp.autocast(...):
with torch._functorch.config.patch(backward_pass_autocast="same_as_forward")
y = torch.compile(region)(x)
...
# backward pass run under the same autocast context as the compiled region
z.backward()