评价此页
fullgraph=False">

嵌套图中断#

创建时间:2025 年 7 月 28 日 | 最后更新时间:2025 年 7 月 28 日

摘要

  • 嵌套函数中的图中断可能导致编译器行为难以理解,我们将在下面进行文档说明。

  • 嵌套图中断会导致 O(N)\mathcal O(N) 重复图中断行为。

回想一下,当 torch.compile 应用于一个函数时,任何嵌套的函数调用也会被跟踪。嵌套图中断 指的是发生在嵌套函数调用中的任何图中断。

def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...

嵌套图中断的恢复语义可能令人困惑,因此我们在此描述其行为。

回想一下,在 fullgraph=False 中,图中断会被处理,即编译到目前为止确定的 FX 图,以常规 Python 运行不支持的代码,然后在新 FX 图中恢复跟踪。恢复函数实际上是一项相当复杂的技术壮举,因此恢复跟踪仅支持顶级函数。

因此,我们可以按照以下方式在嵌套图中断后恢复跟踪(在此限制下):

首先,考虑下面的示例,其中 torch.compilef 开始跟踪,并一直跟踪直到遇到 inner1 中的图中断。

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    x = x + 8

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    x = x + 32

f(torch.randn(3))

由于我们只能从顶级函数恢复,因此我们在 f 中对 inner2 的调用进行图中断。

# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))

inner2 然后会自动编译为顶级函数。我们一直跟踪直到再次遇到 inner1 中的图中断。

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    x = x + 8

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))

然后,我们在 inner2 中对 inner1 的调用进行图中断。

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

inner1 然后会自动编译为顶级函数。图中断来自 inner1,因此我们正常处理该图中断。

# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))

inner1 被正常处理。

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

因此,初始代码在语义上等同于:

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))

请特别注意,我们跟踪了 3 个顶级函数,并且跟踪了相同的图中断 3 次。这就是为什么在使用 torch.compile 时可能会遇到重复图中断的原因。

总而言之,嵌套图中断的处理方式如下:

  • 从顶级函数一直跟踪到嵌套的图中断。

  • 在顶级函数中,在调用二级函数时进行图中断。

  • 编译到目前为止跟踪到的 PyTorch 操作并运行编译后的图。

  • 调用二级函数,该函数会被自动编译为顶级函数。

  • 在调用二级函数后恢复跟踪。

请注意,处理此图中断的运行时为 O(NK)\mathcal O(NK),其中 NN 是嵌套深度,KK 是从顶级函数到图中断的指令数。我们最终会跟踪 O(N2)\mathcal O(N^2) 帧,并且我们跟踪相同的图中断 O(N)\mathcal O(N) 次。