嵌套图中断#
创建时间:2025 年 7 月 28 日 | 最后更新时间:2025 年 7 月 28 日
摘要
嵌套函数中的图中断可能导致编译器行为难以理解,我们将在下面进行文档说明。
嵌套图中断会导致 重复图中断行为。
回想一下,当 torch.compile 应用于一个函数时,任何嵌套的函数调用也会被跟踪。嵌套图中断 指的是发生在嵌套函数调用中的任何图中断。
def inner(x):
...
torch._dynamo.graph_break() # nested graph break
...
@torch.compile
def outer(x):
...
y = inner(x)
...
嵌套图中断的恢复语义可能令人困惑,因此我们在此描述其行为。
回想一下,在 fullgraph=False 中,图中断会被处理,即编译到目前为止确定的 FX 图,以常规 Python 运行不支持的代码,然后在新 FX 图中恢复跟踪。恢复函数实际上是一项相当复杂的技术壮举,因此恢复跟踪仅支持顶级函数。
因此,我们可以按照以下方式在嵌套图中断后恢复跟踪(在此限制下):
首先,考虑下面的示例,其中 torch.compile 从 f 开始跟踪,并一直跟踪直到遇到 inner1 中的图中断。
def inner1(x):
x = x + 1
torch._dynamo.graph_break() # stop tracing due to graph break
return x + 2
def inner2(x):
x = x + 4
x = inner1(x)
x = x + 8
@torch.compile
def f(x):
# start tracing from here
x = x + 16
x = inner2(x)
x = x + 32
f(torch.randn(3))
由于我们只能从顶级函数恢复,因此我们在 f 中对 inner2 的调用进行图中断。
# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
y = x + 16
z = inner2(y)
return torch.compile(resume_f_semantics)(z)
def resume_f_semantics(x):
return x + 32
compiled_f_semantics(torch.randn(3))
inner2 然后会自动编译为顶级函数。我们一直跟踪直到再次遇到 inner1 中的图中断。
def inner1(x):
x = x + 1
torch._dynamo.graph_break() # stop tracing due to graph break
return x + 2
# this torch.compile is automatically applied
@torch.compile
def inner2(x):
# start tracing from here
x = x + 4
x = inner1(x)
x = x + 8
def compiled_f_semantics(x):
y = x + 16
z = inner2(y)
return torch.compile(resume_f_semantics)(z)
def resume_f_semantics(x):
return x + 32
compiled_f_semantics(torch.randn(3))
然后,我们在 inner2 中对 inner1 的调用进行图中断。
def compiled_inner2_semantics(x):
y = x + 4
z = inner1(y)
return torch.compile(resume_inner2_semantics)(z)
def resume_inner2_semantics(x):
return x + 8
inner1 然后会自动编译为顶级函数。图中断来自 inner1,因此我们正常处理该图中断。
# this torch.compile is automatically applied
@torch.compile
def inner1(x):
# start tracing from here
x = x + 1
torch._dynamo.graph_break() # stop tracing due to graph break
return x + 2
def compiled_f_semantics(x):
y = x + 16
z = compiled_inner2_semantics(y)
return torch.compile(resume_f_semantics)(z)
def resume_f_semantics(x):
return x + 32
def compiled_inner2_semantics(x):
y = x + 4
z = inner1(y)
return torch.compile(resume_inner2_semantics)(z)
def resume_inner2_semantics(x):
return x + 8
compiled_f_semantics(torch.randn(3))
inner1 被正常处理。
def compiled_inner1_semantics(x):
y = x + 1
torch._dynamo.graph_break()
return torch.compile(resume_inner1_semantics)(y)
def resume_inner1_semantics(x):
return x + 2
因此,初始代码在语义上等同于:
def compiled_f_semantics(x):
y = x + 16
z = compiled_inner2_semantics(y)
return torch.compile(resume_f_semantics)(z)
def resume_f_semantics(x):
return x + 32
def compiled_inner2_semantics(x):
y = x + 4
z = compiled_inner1_semantics(y)
return torch.compile(resume_inner2_semantics)(z)
def resume_inner2_semantics(x):
return x + 8
def compiled_inner1_semantics(x):
y = x + 1
torch._dynamo.graph_break()
return torch.compile(resume_inner1_semantics)(y)
def resume_inner1_semantics(x):
return x + 2
compiled_f_semantics(torch.randn(3))
请特别注意,我们跟踪了 3 个顶级函数,并且跟踪了相同的图中断 3 次。这就是为什么在使用 torch.compile 时可能会遇到重复图中断的原因。
总而言之,嵌套图中断的处理方式如下:
从顶级函数一直跟踪到嵌套的图中断。
在顶级函数中,在调用二级函数时进行图中断。
编译到目前为止跟踪到的 PyTorch 操作并运行编译后的图。
调用二级函数,该函数会被自动编译为顶级函数。
在调用二级函数后恢复跟踪。
请注意,处理此图中断的运行时为 ,其中 是嵌套深度, 是从顶级函数到图中断的指令数。我们最终会跟踪 帧,并且我们跟踪相同的图中断 次。