Toggling error_on_graph_break
Created On: Sep 03, 2025 | Last Updated On: Sep 03, 2025
Summary
When `fullgraph=False`, we can use `torch._dynamo.error_on_graph_break()` for more flexibility in dealing with graph breaks.
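So far, we have introduced two ways of dealing with graph breaks in `torch.compile`. `fullgraph=True` errors on the first graph break and additionally guarantees that only one graph is traced from the code. `fullgraph=False` continues tracing even when graph breaks are encountered.

What if we want to disallow graph breaks for most of the code, but there are a few problematic functions whose graph breaks are hard to remove and which we are willing to tolerate? We can use `torch._dynamo.error_on_graph_break()` for this.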
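`torch.compile` has an `error_on_graph_break` setting (initially `False`). If a graph break or compiler error occurs in code while `error_on_graph_break` is `False`, `torch.compile` attempts to continue compilation after the graph break/error. If `error_on_graph_break` is `True`, `torch.compile` aborts compilation and propagates the error to user code.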
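A notable difference between `error_on_graph_break=True` and `fullgraph=True` is that the former **does not guarantee that a single graph will be captured**. `error_on_graph_break` **can be toggled arbitrarily during compilation** by using the `torch._dynamo.error_on_graph_break()` context manager/decorator. In contrast, once `fullgraph` is set to `True`, it cannot be set back to `False`. Finally, `error_on_graph_break` has lower precedence than `fullgraph`: `error_on_graph_break` only takes effect when `fullgraph=False`.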
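The precedence rule above can be condensed into a small decision function. This is a toy model of the documented behavior only, not PyTorch's implementation; `graph_break_raises` is a hypothetical name, and the model deliberately ignores the other difference noted above (only `fullgraph=True` guarantees a single graph).

```python
# Toy model of the documented precedence rule; NOT PyTorch's implementation.
def graph_break_raises(fullgraph: bool, error_on_graph_break: bool) -> bool:
    """Return True if a graph break would abort compilation in this toy model."""
    if fullgraph:
        # fullgraph=True always errors; error_on_graph_break is ignored.
        return True
    # With fullgraph=False, the current error_on_graph_break value decides.
    return error_on_graph_break


# Print all four combinations.
for fg in (False, True):
    for eogb in (False, True):
        outcome = "error" if graph_break_raises(fg, eogb) else "continue"
        print(f"fullgraph={fg}, error_on_graph_break={eogb} -> {outcome}")
```

Only the combination `fullgraph=False, error_on_graph_break=False` lets compilation continue past a graph break in this model.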
`error_on_graph_break(False)` example
```python
import torch

@torch._dynamo.error_on_graph_break(False)
def code_with_a_difficult_graph_break(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

def inner(x):
    return code_with_a_difficult_graph_break(x)

# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return inner(x)

# No error, but there is a graph break
fn(torch.randn(3))
```
Graph break in user code at /tmp/ipykernel_880/1452578661.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "/tmp/ipykernel_880/1452578661.py", line 17, in <module>
fn(torch.randn(3))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "/tmp/ipykernel_880/1452578661.py", line 14, in fn
return inner(x)
File "/tmp/ipykernel_880/1452578661.py", line 8, in inner
return code_with_a_difficult_graph_break(x)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "/tmp/ipykernel_880/1452578661.py", line 4, in code_with_a_difficult_graph_break
torch._dynamo.graph_break()
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/1452578661.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/1452578661.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/1452578661.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
tensor([1.5304, 1.4924, 0.9838])
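Using `error_on_graph_break(False)` under `error_on_graph_break(True)` is helpful when we want to minimize graph breaks (i.e., follow the `fullgraph=True` programming model), but some sections of code have graph breaks that are hard to work around and are not performance-critical.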
`error_on_graph_break()` can also be used as a context manager.
```python
# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    x = x + 1
    with torch._dynamo.error_on_graph_break(False):
        torch._dynamo.graph_break()  # no error
    return x + 2

# No error, but there is a graph break
fn(torch.randn(3))
```
Graph break in user code at /tmp/ipykernel_880/737485247.py:7
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "/tmp/ipykernel_880/737485247.py", line 11, in <module>
fn(torch.randn(3))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "/tmp/ipykernel_880/737485247.py", line 7, in fn
torch._dynamo.graph_break() # no error
tensor([4.5572, 3.4907, 1.6146])
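Serving as both a decorator and a context manager is a common Python pattern, and `contextlib.ContextDecorator` is one standard way to build it. The sketch below is a hypothetical stand-in (`toggle_setting`, `read_setting`, and the module-level flag are invented for illustration); it is not how `torch._dynamo.error_on_graph_break` is implemented, but it shows the behavior the examples rely on: the value applies inside the region, and the previous value is restored on exit, even if the body raises.

```python
import contextlib

# Hypothetical global flag standing in for the compiler's internal setting.
_current_setting = False


class toggle_setting(contextlib.ContextDecorator):
    """Hypothetical toggle usable both as a decorator and as a context manager."""

    def __init__(self, value: bool):
        self.value = value
        self._saved = []  # stack, so the same object can be re-entered

    def __enter__(self):
        global _current_setting
        self._saved.append(_current_setting)
        _current_setting = self.value
        return self

    def __exit__(self, exc_type, exc, tb):
        global _current_setting
        # Restore the enclosing value even if the body raised.
        _current_setting = self._saved.pop()
        return False  # never swallow exceptions


def read_setting() -> bool:
    return _current_setting


@toggle_setting(True)
def decorated():
    # The decorator makes the setting True for the whole call...
    with toggle_setting(False):
        # ...but the innermost region wins while it is active.
        inner = read_setting()
    return read_setting(), inner
```

Calling `decorated()` returns `(True, False)`, and `read_setting()` is `False` again once the call returns.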
You can use monkey patching to toggle `error_on_graph_break` for code whose source you cannot edit (e.g., framework code).
```python
class ThirdPartyModule(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        torch._dynamo.graph_break()
        return x + 2

tp_mod = ThirdPartyModule()
tp_mod.forward = torch._dynamo.error_on_graph_break(False)(tp_mod.forward)

@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return tp_mod.forward(x)

# No error, but there is a graph break
fn(torch.randn(3))
```
Graph break in user code at /tmp/ipykernel_880/2112598647.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "/tmp/ipykernel_880/2112598647.py", line 16, in <module>
fn(torch.randn(3))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "/tmp/ipykernel_880/2112598647.py", line 13, in fn
return tp_mod.forward(x)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "/tmp/ipykernel_880/2112598647.py", line 4, in forward
torch._dynamo.graph_break()
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/2112598647.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/2112598647.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
tensor([1.7608, 2.9385, 2.1531])
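Assigning to `tp_mod.forward` works because, in plain Python, an attribute set on an instance shadows the method defined on the class, for that instance only. The sketch below shows this with plain classes; `tag_calls` is a hypothetical decorator standing in for `error_on_graph_break(False)`, and `torch.nn.Module`'s own attribute handling is not modeled. It also shows the alternative of patching the class, which affects every instance.

```python
import functools


def tag_calls(fn):
    """Hypothetical decorator standing in for error_on_graph_break(False)."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Tag the result so we can see which calls went through the wrapper.
        return ("wrapped", fn(*args, **kwargs))

    return wrapper


class ThirdPartyModule:
    """Stand-in for third-party code whose source we cannot edit."""

    def forward(self, x):
        return x + 3


a = ThirdPartyModule()
b = ThirdPartyModule()

# Instance-level patch: a.forward is already bound, so the wrapper does not
# receive self; only `a` sees the wrapped version.
a.forward = tag_calls(a.forward)
instance_results = (a.forward(1), b.forward(1))

# Class-level patch: instances without their own override now use the
# wrapper, which receives self as its first positional argument.
ThirdPartyModule.forward = tag_calls(ThirdPartyModule.forward)
class_result = ThirdPartyModule().forward(1)

print(instance_results)  # (('wrapped', 4), 4)
print(class_result)      # ('wrapped', 4)
```

Patching a single instance, as the example above does, keeps the change local; patching the class is broader and harder to undo.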
`error_on_graph_break(True)` example
```python
@torch._dynamo.error_on_graph_break(True)
def inner2(x):
    x = x + 1
    torch._dynamo.graph_break()  # error
    return x + 2

def inner(x):
    return inner2(x)

# fullgraph=False, error_on_graph_break=False
@torch.compile
def fn(x):
    x = x + 4
    torch._dynamo.graph_break()  # no error
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "/tmp/ipykernel_880/2379101916.py", line 15, in torch_dynamo_resume_in_fn_at_14
return inner(x)
File "/tmp/ipykernel_880/2379101916.py", line 8, in inner
return inner2(x)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "/tmp/ipykernel_880/2379101916.py", line 4, in inner2
torch._dynamo.graph_break() # error
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Graph break in user code at /tmp/ipykernel_880/2379101916.py:14
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "/tmp/ipykernel_880/2379101916.py", line 18, in <module>
fn(torch.randn(3))
File "/tmp/ipykernel_880/2379101916.py", line 14, in fn
torch._dynamo.graph_break() # no error
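Using `error_on_graph_break(True)` under `error_on_graph_break(False)` is helpful when we want the flexibility of `torch.compile` (i.e., follow the `fullgraph=False` programming model), but some sections of code are performance-critical and we want to ensure that they contain no graph breaks.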
`error_on_graph_break` nesting behavior
`torch._dynamo.error_on_graph_break()` also affects the `error_on_graph_break` setting of nested calls.
```python
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

def inner2(x):
    with torch._dynamo.error_on_graph_break(False):
        return inner(x)

@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return inner2(x)

# no error
fn(torch.randn(3))
```
Graph break in user code at /tmp/ipykernel_880/1007149706.py:3
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "/tmp/ipykernel_880/1007149706.py", line 16, in <module>
fn(torch.randn(3))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "/tmp/ipykernel_880/1007149706.py", line 13, in fn
return inner2(x)
File "/tmp/ipykernel_880/1007149706.py", line 8, in inner2
return inner(x)
File "/tmp/ipykernel_880/1007149706.py", line 3, in inner
torch._dynamo.graph_break()
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/1007149706.py:3
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/1007149706.py:3
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
tensor([3.5942, 2.0674, 2.8445])
`torch._dynamo.error_on_graph_break()` can be used inside another `torch._dynamo.error_on_graph_break()` region.
```python
def inner(x):
    x = x + 1
    with torch._dynamo.error_on_graph_break(False):
        torch._dynamo.graph_break()
    return x + 2

def inner2(x):
    with torch._dynamo.error_on_graph_break(True):
        return inner(x)

@torch.compile
def fn(x):
    return inner2(x)

# no error
fn(torch.randn(3))
```
Graph break in user code at /tmp/ipykernel_880/1343774799.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "/tmp/ipykernel_880/1343774799.py", line 16, in <module>
fn(torch.randn(3))
File "/tmp/ipykernel_880/1343774799.py", line 13, in fn
return inner2(x)
File "/tmp/ipykernel_880/1343774799.py", line 9, in inner2
return inner(x)
File "/tmp/ipykernel_880/1343774799.py", line 4, in inner
torch._dynamo.graph_break()
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/1343774799.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
Graph break (user stack suppressed due to duplicate graph break) in user code at /tmp/ipykernel_880/1343774799.py:4
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
tensor([1.3759, 2.1866, 3.8692])
Interaction with fullgraph
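fullgraph=True takes precedence over error_on_graph_break. Wrapping a function in error_on_graph_break(False) does not stop a graph break inside it from raising an error when the caller is compiled with fullgraph=True.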
@torch._dynamo.error_on_graph_break(False)  # has no effect under fullgraph=True
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=True)  # takes precedence over error_on_graph_break
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "/tmp/ipykernel_880/2331424258.py", line 9, in fn
return inner(x)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "/tmp/ipykernel_880/2331424258.py", line 3, in inner
x = x + 1
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
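Once fullgraph=True is set, it cannot be switched back to fullgraph=False. In the two examples below, a graph break raises an error both when a fullgraph=False function is called from a fullgraph=True function and when a fullgraph=True function is called from a fullgraph=False function.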
@torch.compile(fullgraph=False)  # cannot turn fullgraph back off inside fn
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=True)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "/tmp/ipykernel_880/262151723.py", line 9, in fn
return inner(x)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py", line 264, in getattr_and_trace
return fn(*args[2:], **kwargs)
File "/tmp/ipykernel_880/262151723.py", line 3, in inner
x = x + 1
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
@torch.compile(fullgraph=True)  # the graph break below raises even though fn uses fullgraph=False
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=False)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "/tmp/ipykernel_880/2119173801.py", line 3, in inner
x = x + 1
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Graph break in user code at /tmp/ipykernel_880/2119173801.py:3
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "/opt/conda/envs/py_3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/envs/py_3.10/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
app.launch_new_instance()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
app.start()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
self.io_loop.start()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 211, in start
self.asyncio_loop.run_forever()
File "/opt/conda/envs/py_3.10/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/opt/conda/envs/py_3.10/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/opt/conda/envs/py_3.10/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 519, in dispatch_queue
await self.process_one()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 508, in process_one
await dispatch(*args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 400, in dispatch_shell
await result
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 368, in execute_request
await super().execute_request(stream, ident, parent)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
reply_content = await reply_content
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 455, in do_execute
res = shell.run_cell(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 577, in run_cell
return super().run_cell(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3006, in run_cell
result = self._run_cell(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3061, in _run_cell
result = runner(coro)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3266, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3445, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_880/2119173801.py", line 12, in <module>
fn(torch.randn(3))
File "/tmp/ipykernel_880/2119173801.py", line 9, in fn
return inner(x)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py", line 264, in getattr_and_trace
return fn(*args[2:], **kwargs)
File "/tmp/ipykernel_880/2119173801.py", line 3, in inner
x = x + 1
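The two nested examples above are consistent with a simple rule: while a function is being traced, fullgraph is effectively True if the current or any enclosing torch.compile call set it. The sketch below models that rule with a hypothetical `effective_fullgraph` helper; it only illustrates the observed behavior and is not Dynamo's internal logic.

```python
from typing import List


def effective_fullgraph(fullgraph_flags: List[bool]) -> bool:
    """Hypothetical model, not torch._dynamo's implementation.

    fullgraph_flags holds the fullgraph= argument of every torch.compile'd
    function on the current call chain, outermost first. Once any of them is
    True, a nested torch.compile(fullgraph=False) cannot switch it back off.
    """
    return any(fullgraph_flags)


# Outer fullgraph=True, inner fullgraph=False: the graph break in inner still errors.
print(effective_fullgraph([True, False]))   # True
# Outer fullgraph=False, inner fullgraph=True: the graph break in inner errors.
print(effective_fullgraph([False, True]))   # True
# No fullgraph=True anywhere: graph breaks are allowed.
print(effective_fullgraph([False, False]))  # False
```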
Summary of fullgraph=True/False and error_on_graph_break
The table below summarizes the differences between fullgraph=True/False and error_on_graph_break.
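| | error_on_graph_break=True | error_on_graph_break=False (default) |
|---|---|---|
| fullgraph=True | Graph breaks result in errors. Only the first graph break is reported. A single graph is guaranteed. | Same as error_on_graph_break=True, since fullgraph=True takes precedence. |
| fullgraph=False (default) | Graph breaks result in errors. Only the first graph break is reported. A single graph is not guaranteed. | Compilation continues after graph breaks. All graph breaks are reported. |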
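The summary table above can also be read as a lookup from the two settings to three observable properties. The helper below is a hypothetical encoding of that table (the `Behavior` and `summarize` names are illustrative, not a PyTorch API).

```python
from typing import NamedTuple


class Behavior(NamedTuple):
    errors_on_graph_break: bool     # does a graph break raise an error?
    reports_all_graph_breaks: bool  # False means only the first one is reported
    single_graph_guaranteed: bool   # is exactly one graph traced?


def summarize(fullgraph: bool, error_on_graph_break: bool) -> Behavior:
    """Hypothetical helper encoding the summary table (not a PyTorch API)."""
    if fullgraph:
        # fullgraph=True takes precedence, so error_on_graph_break is ignored.
        return Behavior(True, False, True)
    if error_on_graph_break:
        # Errors like fullgraph=True, but without the single-graph guarantee.
        return Behavior(True, False, False)
    # Default: keep compiling past graph breaks and report all of them.
    return Behavior(False, True, False)


for fg in (True, False):
    for eogb in (True, False):
        print(f"fullgraph={fg}, error_on_graph_break={eogb}: {summarize(fg, eogb)}")
```

Checking fullgraph first makes the precedence explicit: the error_on_graph_break branch is only reached when fullgraph=False.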