Identifying and Eliminating Graph Breaks with fullgraph=True
Created On: Jul 28, 2025 | Last Updated: Jul 28, 2025
Using torch.compile(fullgraph=False) (the default) is a good way to get started with torch.compile: it supports all Python programs out of the box by allowing graph breaks, and it delivers good performance in common cases. However, if you are trying to squeeze more performance out of your model, you should think explicitly about which regions of code should be compiled:
We recommend using torch.compile(fullgraph=True) to find and eliminate graph breaks in your code.
If you are a library developer (or are testing whether your code "works" with torch.compile), we also recommend testing with torch.compile(fullgraph=True).
torch.compile(fullgraph=True) offers a stronger guarantee than fullgraph=False: we will always capture a single FX graph to compile (and raise an error if we cannot do so because of a graph break). **In particular, you are forced to resolve every graph break you encounter.**
There are several strategies for resolving graph breaks.
Strategy 1: Rewrite Unsupported Code to Use Features That Dynamo Supports
Many graph break error messages include suggestions for how to rewrite your code to avoid the graph break. If a graph break is still hard to resolve, move on to the next strategy, or file an issue on the PyTorch GitHub repository.
For more examples of graph breaks and how to resolve them, see Common Graph Breaks.
Example: Dynamo does not support calling next on a list_iterator object passed as an input to the compiled function.
import torch

@torch.compile(fullgraph=True)
def f(xs):
    a = next(xs)
    b = next(xs)
    return a + b

xs = [torch.tensor(1.), torch.tensor(2.)]
try:
    out = f(iter(xs))
except Exception as e:
    print(e)
Unsupported method call
Explanation: Dynamo does not know how to trace method `__next__` of class `list_iterator`
Hint: Avoid calling `list_iterator.__next__` in your code.
Hint: Please report an issue to PyTorch.
Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope.
Hint: List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, (2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a function, or (4) use Python 3.12+.
Developer debug context: call_method UserDefinedObjectVariable(list_iterator) __next__ [] {}
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html
from user code:
File "/tmp/ipykernel_904/1195637716.py", line 3, in f
a = next(xs)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Instead, rewrite the compiled function to accept a list.
@torch.compile(fullgraph=True)
def f_rewritten(xs):
    it = iter(xs)
    a = next(it)
    b = next(it)
    return a + b

f_rewritten(xs)
tensor(3.)
Strategy 2: Pure Functions Can Always Be Compiled via an Escape Hatch
**Summary**: The space of all Python functions is vast, so Dynamo cannot trace every Python function without graph breaks. For "pure" Python functions that Dynamo cannot trace without graph breaks, we provide escape hatches that attempt to trace them anyway:
Use custom_op or triton_op for pure Triton kernels.
Use nonstrict_trace for pure functions that only use PyTorch Tensor operations.
Use custom_op for all other pure functions.
A "pure function" is a function with the following properties:
Determinism. Given the same inputs, a pure function always returns the same outputs.
No external side effects. A pure function has no externally visible side effects, such as modifying external state or performing I/O. Side effects internal to the function are allowed (e.g. mutating intermediate tensors); one notable exception is that mutating the function's input tensors via torch.* operations is also generally allowed.
Explicit inputs/outputs. All input data must be passed in through the function's parameters, and all outputs must be returned from the function.
For examples, see Pure Functions.
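As an extra illustrative sketch (the function names below are made up for this example), compare a function that has these properties with one that does not:

import torch

# Pure: deterministic, no externally visible side effects, explicit inputs and outputs.
def scaled_sum(x, weight):
    return (x * weight).sum()

call_count = 0

# Not pure: reads and mutates global state, so its output depends on more than its arguments.
def counting_scale(x):
    global call_count
    call_count += 1
    return x * call_count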
In theory, Dynamo can handle a wide variety of non-pure functions, but it may be missing coverage for specific Python language features. Pure functions, however, can always be compiled via the escape hatches.
If you hit a graph break, you can refactor the code around it into a pure function and use an escape hatch that bypasses Dynamo tracing.
Use torch._dynamo.nonstrict_trace if you want the Tensor operations in the function to show up in the Dynamo output graph (and therefore be optimizable). nonstrict_trace tells Dynamo to use **non-strict tracing**.
Use a custom operator if you want the function to be "opaque" with respect to torch.compile (both the Dynamo frontend and the backends).
Note that nothing prevents these escape hatches from being applied to non-pure functions, but **we do not provide any soundness guarantees** in that case.
Example: if Dynamo does not support some Python feature or API but it is traceable under non-strict tracing (e.g. it only uses PyTorch operations), use torch._dynamo.nonstrict_trace to capture it.
# This is a function that Dynamo doesn't support (due to the graph_break() call).
def g(x):
    y = x.sin()
    torch._dynamo.graph_break()
    z = y.sin()
    return z

@torch.compile(fullgraph=True)
def f(x):
    w = x.sin()
    return g(w)

x = torch.randn(3)
try:
    f(x)  # Graph break: there was a call to torch._dynamo.graph_break()
except Exception as e:
    print(e)

@torch.compile(fullgraph=True)
def f_rewritten(x):
    w = x.sin()
    return torch._dynamo.nonstrict_trace(g)(w)

f_rewritten(x)  # works
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "/tmp/ipykernel_904/2422769198.py", line 11, in f
return g(w)
File "/tmp/ipykernel_904/2422769198.py", line 4, in g
torch._dynamo.graph_break()
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
tensor([-0.0326, -0.7442, -0.2731])
Example: use a custom operator to create a function that is opaque with respect to torch.compile.
from torch.utils.cpp_extension import load_inline

# C++ source code for the square operation
cpp_source = """
torch::Tensor square_cpu(torch::Tensor input) {
    // Check that input is a CPU tensor
    TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");

    // Create output tensor with same shape and dtype as input
    torch::Tensor output = torch::empty_like(input);

    // Get data pointers
    float* input_data = input.data_ptr<float>();
    float* output_data = output.data_ptr<float>();

    // Get total number of elements
    int64_t numel = input.numel();

    // For loop to compute square of each element
    for (int64_t i = 0; i < numel; i++) {
        output_data[i] = input_data[i] * input_data[i];
    }

    return output;
}
"""

# Load the extension inline
square_module = load_inline(
    name="square_cpu_kernel",
    cpp_sources=cpp_source,
    functions=["square_cpu"],
    verbose=True
)

def square(x):
    return square_module.square_cpu(x)

@torch.compile(fullgraph=True)
def f(x):
    return square(x)

try:
    f(torch.randn(3, 3))  # graph break
except Exception as e:
    print(e)
[1/2] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=square_cpu_kernel -DTORCH_API_INCLUDE_EXTENSION_H -isystem /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/include -isystem /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/py_3.10/include/python3.10 -fPIC -std=c++17 -c /var/lib/jenkins/.cache/torch_extensions/py310_cpu/square_cpu_kernel/main.cpp -o main.o
[2/2] c++ main.o -shared -L/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib -lc10 -ltorch_cpu -ltorch -ltorch_python -o square_cpu_kernel.so
Attempted to call function marked as skipped
Explanation: Dynamo does not know how to trace the builtin `square_cpu_kernel.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.square_cpu.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
Developer debug context: module: square_cpu_kernel, qualname: pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.square_cpu, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html
from user code:
File "/tmp/ipykernel_904/2059008136.py", line 41, in f
return square(x)
File "/tmp/ipykernel_904/2059008136.py", line 37, in square
return square_module.square_cpu(x)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:1598: UserWarning: Dynamo does not know how to trace the builtin `square_cpu_kernel.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.square_cpu.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
# Use torch.library.custom_op to define a new custom operator.
# Custom operators are opaque with respect to torch.compile:
# that is, torch.compile does not peek into them.
@torch.library.custom_op("mylib::square", mutates_args=())
def square(x: torch.Tensor) -> torch.Tensor:
    return square_module.square_cpu(x)

# Use register_fake to add a ``FakeTensor`` kernel for the operator
@square.register_fake
def _(x):
    return x.new_empty(x.size())

print(f(torch.randn(3, 3)))  # no graph break
tensor([[1.2862e-01, 8.5591e-02, 2.3450e-01],
[1.5921e-01, 4.4706e-01, 1.5394e+00],
[8.1086e-04, 4.4906e-01, 5.1608e-01]])
For more information about triton_op for custom Triton kernels, see the User-Defined Triton Kernels tutorial.
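As a rough sketch of what that looks like (assuming a recent PyTorch version with torch.library.triton_op, a CUDA device, and Triton installed; the kernel and operator names below are made up for illustration), a pure Triton kernel can be exposed as a custom operator so that compiled code can call it without a graph break:

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def square_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * x, mask=mask)

# triton_op registers the function as a custom operator; wrap_triton lets
# torch.compile see the Triton kernel launch inside it.
@triton_op("mylib::square_triton", mutates_args={})
def square_triton(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    wrap_triton(square_kernel)[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out

@torch.compile(fullgraph=True)
def f(x):
    return square_triton(x)

# f(torch.randn(3, device="cuda"))  # no graph break on a machine with a CUDA device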
Strategy 3: Don't Compile the Code
Not all code benefits from being compiled. torch.compile is a compiler for Tensor computations; it cannot optimize things like disk I/O. Try to refactor the code so that the unsupported code is not called inside a compiled region.
@torch.compile(fullgraph=True)
def f(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")
    z = y ** 3 / 6
    return z

x = torch.randn(3)
try:
    f(x)  # Graph break: torch.save not supported
except Exception as e:
    print(e)
Attempted to call function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `save` in file `/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/serialization.py` should not be traced.
Hint: Avoid calling the function `save`.
Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `save` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.
Developer debug context: module: torch.serialization, qualname: save, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html
from user code:
File "/tmp/ipykernel_904/150060719.py", line 4, in f
torch.save(y, "foo.pt")
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
def f_rewritten(x):
    y = g(x)
    torch.save(y, "foo.pt")
    z = h(y)
    return z

@torch.compile(fullgraph=True)
def g(x):
    y = x ** 2 / 2
    return y

@torch.compile(fullgraph=True)
def h(y):
    z = y ** 3 / 6
    return z

f_rewritten(x)
tensor([1.3869e-06, 4.1821e-01, 9.4249e-03])