torch.compile 具有不同的 autograd 语义#
创建日期:2025 年 6 月 26 日 | 最后更新日期:2025 年 6 月 26 日
当您将 torch.compile 应用到模型前向传播的某个函数时,它会自动为编译后的函数生成一个反向传播。在编译期间,它会为反向传播跟踪出一个图,该图将在每次调用 autograd 时使用。我们将 torch.compile 内部负责此任务的组件称为 AOTDispatcher(有时也称为 AOTAutograd)。
因此,torch.compile 会在前向传播的函数编译过程中,将计算的细节“烘焙”到跟踪出的反向传播图中。然而,在 eager 模式的 PyTorch 中,反向传播是动态的:在前向传播之外,您可以将 tensor.backward() 或 torch.autograd.grad(...) 的调用包装在可能改变其行为的上下文管理器中。
此页面记录了 torch.compile 的 autograd 语义与 eager 模式 PyTorch 的不同之处,以及如何规避这些差异。
Autocast 行为#
torch.compile 会预先假设反向传播是否会在环境 autocast 上下文管理器下运行。默认情况下,使用 torch._functorch.config.backward_pass_autocast 来控制该假设;不正确的假设可能导致静默的错误。
选项包括:
"same_as_forward"(默认)。我们假设torch.compile编译区域的反向传播将在该区域运行的相同 autocast 上下文管理器下运行(如果存在)。如果您的代码如下所示,请使用此选项:with torch.amp.autocast(...): y = torch.compile(region)(x) ... # backward pass run under the same autocast context as the compiled region z.backward()
"off"。我们假设torch.compile编译区域的反向传播不会在任何 autocast 上下文管理器下运行。如果您的代码如下所示,请使用此选项:with torch.amp.autocast(...): y = torch.compile(region)(x) ... # Backward pass runs under no autocast. z.backward()
还有第三个选项。如果将
torch._functorch.config.backward_pass_autocast设置为 kwargs 列表,我们将假定反向传播在由这些 kwargs 构建的 autocast 上下文下运行。例如,如果您的代码如下所示:
y = torch.compile(region)(x) ... # Backward pass runs under special context manager with torch.amp.autocast(**kwargs): z.backward()
则将
torch._functorch.config.backward_pass_autocast = kwargs设置为。
使用 patch 将选项应用于特定的 torch.compile 调用。
with torch.amp.autocast(...):
with torch._functorch.config.patch(backward_pass_autocast="same_as_forward")
y = torch.compile(region)(x)
...
# backward pass run under the same autocast context as the compiled region
z.backward()