评价此页
torch._dynamo.nonstrict_trace">

使用 torch._dynamo.nonstrict_trace#

创建时间:2025 年 7 月 28 日 | 最后更新时间:2025 年 7 月 28 日

摘要

  • 使用 nonstrict_tracetorch.compile 编译区域内部使用非严格跟踪来跟踪函数。您可能希望这样做,因为 Dynamo 图在函数内部的某个地方断裂了,而您确定该函数是可进行非严格跟踪的。

考虑以下场景

def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n
try:
    func(torch.rand(10))
except Exception as e:
    print(e)
Call to `torch._dynamo.graph_break()`
  Explanation: User-inserted graph break. Message: None
  Hint: Remove the `torch._dynamo.graph_break()` call.

  Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html

from user code:
   File "/tmp/ipykernel_850/2253748958.py", line 9, in func
    n = get_magic_num()
  File "/tmp/ipykernel_850/2253748958.py", line 5, in get_magic_num
    torch._dynamo.graph_break()

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

如果我们运行上面的代码,我们将收到一个来自 Dynamo 的错误,因为尽管用户指定了 fullgraph=True,但它仍然看到一个图断裂。

在这些情况下,如果用户仍然希望保持 fullgraph=True,他们通常有几种选择

  1. 图断裂是由于 Dynamo 尚不支持的语言特性。在这种情况下,用户要么重写他们的代码,要么在 GitHub 上提交一个 issue。

  2. 图断裂是由于调用了用 C 实现的函数。在这种情况下,用户可以尝试使用自定义操作。用户也可以尝试提供一个 polyfill(Python 中的引用实现),以便 Dynamo 可以跟踪它。

  3. 最坏的情况——内部编译器错误。在这种情况下,用户很可能需要在 GitHub 上提交一个 issue。

除了所有这些选项之外,PyTorch 还提供了一个替代方案 torch._dynamo.nonstrict_trace,前提是引发图断裂的函数调用满足某些要求

  • 通用非严格跟踪 的要求。

  • 输入和输出必须包含基本类型(例如,intfloatlistdicttorch.Tensor),或者已注册到 torch.utils._pytree 的用户定义类型。

  • 该函数必须定义在 torch.compile 编译区域之外。

  • 函数读取的任何非输入值将被视为常量(例如,全局张量),并且不会对其进行保护。

在跟踪对 torch._dynamo.nonstrict_trace 跟踪的函数的调用时,torch.compile 会切换到非严格跟踪,并且 FX 图最终将包含该函数内部发生的所有相关张量操作。

对于上面的示例,我们可以使用 torch._dynamo.nonstrict_trace 消除 图断裂

@torch._dynamo.nonstrict_trace
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.1627, 42.1384, 42.9075, 42.5242, 42.4176, 42.1747, 42.9599, 42.2383,
        42.5449, 42.0285])

请注意,也可以在 torch.compile 编译区域内部使用它

def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = torch._dynamo.nonstrict_trace(get_magic_num)()
    return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.5935, 42.2370, 42.2154, 42.5488, 42.9691, 42.9799, 42.1668, 42.0909,
        42.4228, 42.2204])