评价此页
torch.compile 具有不同的 autograd 语义"

torch.compile 具有不同的 autograd 语义#

创建日期:2025 年 6 月 26 日 | 最后更新日期:2025 年 6 月 26 日

当您将 torch.compile 应用到模型前向传播的某个函数时,它会自动为编译后的函数生成一个反向传播。在编译期间,它会为反向传播跟踪出一个图,该图将在每次调用 autograd 时使用。我们将 torch.compile 内部负责此任务的组件称为 AOTDispatcher(有时也称为 AOTAutograd)。

因此,torch.compile 会在前向传播的函数编译过程中,将计算的细节“烘焙”到跟踪出的反向传播图中。然而,在 eager 模式的 PyTorch 中,反向传播是动态的:在前向传播之外,您可以将 tensor.backward()torch.autograd.grad(...) 的调用包装在可能改变其行为的上下文管理器中。

此页面记录了 torch.compile 的 autograd 语义与 eager 模式 PyTorch 的不同之处,以及如何规避这些差异。

Autocast 行为#

torch.compile 会预先假设反向传播是否会在环境 autocast 上下文管理器下运行。默认情况下,使用 torch._functorch.config.backward_pass_autocast 来控制该假设;不正确的假设可能导致静默的错误。

选项包括:

  • "same_as_forward"(默认)。我们假设 torch.compile 编译区域的反向传播将在该区域运行的相同 autocast 上下文管理器下运行(如果存在)。如果您的代码如下所示,请使用此选项:

    with torch.amp.autocast(...):
        y = torch.compile(region)(x)
        ...
        # backward pass run under the same autocast context as the compiled region
        z.backward()
    
  • "off"。我们假设 torch.compile 编译区域的反向传播不会在任何 autocast 上下文管理器下运行。如果您的代码如下所示,请使用此选项:

    with torch.amp.autocast(...):
        y = torch.compile(region)(x)
        ...
    # Backward pass runs under no autocast.
    z.backward()
    
  • 还有第三个选项。如果将 torch._functorch.config.backward_pass_autocast 设置为 kwargs 列表,我们将假定反向传播在由这些 kwargs 构建的 autocast 上下文下运行。

    例如,如果您的代码如下所示:

    y = torch.compile(region)(x)
    ...
    # Backward pass runs under special context manager
    with torch.amp.autocast(**kwargs):
        z.backward()
    

    则将 torch._functorch.config.backward_pass_autocast = kwargs 设置为。

使用 patch 将选项应用于特定的 torch.compile 调用。

with torch.amp.autocast(...):
    with torch._functorch.config.patch(backward_pass_autocast="same_as_forward")
    y = torch.compile(region)(x)
    ...
    # backward pass run under the same autocast context as the compiled region
    z.backward()