使用 torch._dynamo.nonstrict_trace
#
创建时间:2025 年 7 月 28 日 | 最后更新时间:2025 年 7 月 28 日
摘要
使用
nonstrict_trace
在torch.compile
编译区域内部使用非严格跟踪来跟踪函数。您可能希望这样做,因为 Dynamo 图在函数内部的某个地方断裂了,而您确定该函数是可进行非严格跟踪的。
考虑以下场景
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = get_magic_num()
return x + n
try:
func(torch.rand(10))
except Exception as e:
print(e)
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "/tmp/ipykernel_850/2253748958.py", line 9, in func
n = get_magic_num()
File "/tmp/ipykernel_850/2253748958.py", line 5, in get_magic_num
torch._dynamo.graph_break()
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
如果我们运行上面的代码,我们将收到一个来自 Dynamo 的错误,因为尽管用户指定了 fullgraph=True
,但它仍然看到一个图断裂。
在这些情况下,如果用户仍然希望保持 fullgraph=True
,他们通常有几种选择
图断裂是由于 Dynamo 尚不支持的语言特性。在这种情况下,用户要么重写他们的代码,要么在 GitHub 上提交一个 issue。
图断裂是由于调用了用 C 实现的函数。在这种情况下,用户可以尝试使用自定义操作。用户也可以尝试提供一个 polyfill(Python 中的引用实现),以便 Dynamo 可以跟踪它。
最坏的情况——内部编译器错误。在这种情况下,用户很可能需要在 GitHub 上提交一个 issue。
除了所有这些选项之外,PyTorch 还提供了一个替代方案 torch._dynamo.nonstrict_trace
,前提是引发图断裂的函数调用满足某些要求
通用非严格跟踪 的要求。
输入和输出必须包含基本类型(例如,
int
、float
、list
、dict
、torch.Tensor
),或者已注册到torch.utils._pytree
的用户定义类型。该函数必须定义在
torch.compile
编译区域之外。函数读取的任何非输入值将被视为常量(例如,全局张量),并且不会对其进行保护。
在跟踪对 torch._dynamo.nonstrict_trace
跟踪的函数的调用时,torch.compile
会切换到非严格跟踪,并且 FX 图最终将包含该函数内部发生的所有相关张量操作。
对于上面的示例,我们可以使用 torch._dynamo.nonstrict_trace 来 消除
图断裂
@torch._dynamo.nonstrict_trace
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = get_magic_num()
return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.1627, 42.1384, 42.9075, 42.5242, 42.4176, 42.1747, 42.9599, 42.2383,
42.5449, 42.0285])
请注意,也可以在 torch.compile
编译区域内部使用它
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = torch._dynamo.nonstrict_trace(get_magic_num)()
return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.5935, 42.2370, 42.2154, 42.5488, 42.9691, 42.9799, 42.1668, 42.0909,
42.4228, 42.2204])