Introduction to torch.compile
Created On: Mar 15, 2023 | Last Updated: Apr 01, 2026 | Last Verified: Nov 05, 2024
Author: William Wen
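torch.compile is the latest way to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling it into optimized kernels, while requiring minimal code changes.
torch.compile accomplishes this by tracing through your Python code and looking for PyTorch operations. Code that is difficult to trace results in a graph break: a lost optimization opportunity, not an error or silent incorrectness.
torch.compile is available in PyTorch 2.0 and later.
This introduction covers basic torch.compile usage and demonstrates the advantages of torch.compile over our previous PyTorch compiler solution, TorchScript.
For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial.
To troubleshoot issues and gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.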
Required pip dependencies for this tutorial
- torch >= 2.0
- numpy
- scipy
System requirements
- A C++ compiler, such as g++
- Python development package (python-devel/python-dev)
Basic Usage
In this tutorial we turn on some logging to help us see what torch.compile is doing under the hood. The following code prints out the PyTorch operations that torch.compile traces.
import torch
torch._logging.set_logs(graph_code=True)
torch.compile accepts an arbitrary Python function. It can be called on the function directly or used as a decorator.
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))
@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_1_277d075a_04bb_4953_b12e_61f4a1005bdf =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:74 in foo, code: a = torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:75 in foo, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:76 in foo, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
        return (add,)
tensor([[-0.7014, 0.0074, 0.1317],
[-0.2300, -0.6917, -0.1366],
[ 0.1762, -0.0748, 1.1971]])
TRACED GRAPH
===== __compiled_fn_3_e7aa85c3_99ac_4c2a_920d_459b28d12e44 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:85 in opt_foo2, code: a = torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:86 in opt_foo2, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:87 in opt_foo2, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
        return (add,)
tensor([[-0.3453, -0.5549, -0.0870],
[-0.0455, -0.0700, 0.1640],
[ 1.6199, 1.2302, -0.0657]])
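To check that the compiled function is cached and reused rather than retraced on every call, here is a minimal sketch that passes a custom backend to torch.compile. A backend is a callable that receives the captured torch.fx.GraphModule plus example inputs and returns a callable to run. `counting_backend` and `compiled_graphs` are hypothetical names for this illustration. The backend just runs the captured graph without further optimization, so no C++ compiler is needed.

```python
import torch

compiled_graphs = []  # hypothetical: records every graph Dynamo hands to the backend


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Record the captured graph, then run it as-is (no Inductor optimization).
    compiled_graphs.append(gm)
    return gm.forward


def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


opt_foo = torch.compile(foo, backend=counting_backend)

x, y = torch.randn(3, 3), torch.randn(3, 3)
for _ in range(3):
    # The first call traces and compiles; later calls reuse the cached code.
    out = opt_foo(x, y)

print("graphs captured across 3 calls:", len(compiled_graphs))
torch.testing.assert_close(out, foo(x, y))  # compiled and eager results agree
```

When the same shapes and dtypes are passed again, the guards Dynamo recorded at compile time still hold, so the list should stay at a single graph.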
torch.compile is applied recursively, so nested function calls within the top-level compiled function are also compiled.
def inner(x):
    return torch.sin(x)
@torch.compile
def outer(x, y):
    a = inner(x)
    b = torch.cos(y)
    return a + b
print(outer(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_5_39fd61a4_aad3_4a84_8fa5_95ac598c9f26 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:98 in inner, code: return torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:104 in outer, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:105 in outer, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
        return (add,)
tensor([[ 1.0369, -0.2757, 1.0893],
[ 1.4089, 1.2375, 0.7316],
[ 0.9238, 0.8411, 1.1982]])
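You can confirm that `inner` is inlined into `outer`'s graph, rather than compiled separately, by recording the graphs Dynamo captures and inspecting their nodes. This sketch assumes a hypothetical `recording_backend` that only stores each torch.fx.GraphModule. A single captured graph that contains both torch.sin and torch.cos matches the traced graph printed above.

```python
import torch

captured = []  # hypothetical: one entry per graph that Dynamo captures


def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Keep the graph so we can inspect which ops ended up in it.
    captured.append(gm)
    return gm.forward


def inner(x):
    return torch.sin(x)


@torch.compile(backend=recording_backend)
def outer(x, y):
    a = inner(x)  # traced through (inlined), not compiled as a separate graph
    b = torch.cos(y)
    return a + b


outer(torch.randn(3, 3), torch.randn(3, 3))

# Collect the targets of all call_function nodes across the captured graph(s).
targets = [n.target for gm in captured for n in gm.graph.nodes if n.op == "call_function"]
print(len(captured), targets)
```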
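We can also optimize a torch.nn.Module instance, either by calling its .compile() method or by passing the module to torch.compile directly. This is equivalent to running torch.compile on the module's __call__ method, which indirectly calls forward.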
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)
    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))
mod1 = MyModule()
mod1.compile()
print(mod1(torch.randn(3, 3)))
mod2 = MyModule()
mod2 = torch.compile(mod2)
print(mod2(torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_7_a5f50399_ea75_477e_846b_e3a013bd017c =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_lin_parameters_weight_: "f32[3, 3][3, 1]cpu", L_self_modules_lin_parameters_bias_: "f32[3][1]cpu", L_x_: "f32[3, 3][3, 1]cpu"):
        l_self_modules_lin_parameters_weight_ = L_self_modules_lin_parameters_weight_
        l_self_modules_lin_parameters_bias_ = L_self_modules_lin_parameters_bias_
        l_x_ = L_x_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:126 in forward, code: return torch.nn.functional.relu(self.lin(x))
        linear: "f32[3, 3][3, 1]cpu" = torch._C._nn.linear(l_x_, l_self_modules_lin_parameters_weight_, l_self_modules_lin_parameters_bias_); l_x_ = l_self_modules_lin_parameters_weight_ = l_self_modules_lin_parameters_bias_ = None
        relu: "f32[3, 3][3, 1]cpu" = torch.nn.functional.relu(linear); linear = None
        return (relu,)
tensor([[0.0000, 0.0000, 0.1632],
[0.0000, 0.0000, 0.2538],
[0.0000, 0.0176, 0.1700]], grad_fn=<CompiledFunctionBackward>)
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], grad_fn=<CompiledFunctionBackward>)
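Demonstrating Speedups
Let's now demonstrate how torch.compile can speed up a simple PyTorch example. For a demonstration on a more complex model, see our end-to-end torch.compile tutorial.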
def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    u = z * 2
    return u
opt_foo3 = torch.compile(foo3)
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
inp = torch.randn(4096, 4096).cuda()
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])
TRACED GRAPH
===== __compiled_fn_9_0cb2c443_9802_4cf4_b7d8_94ee2e81d6d4 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[4096, 4096][4096, 1]cuda:0"):
        l_x_ = L_x_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:147 in foo3, code: y = x + 1
        y: "f32[4096, 4096][4096, 1]cuda:0" = l_x_ + 1; l_x_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:148 in foo3, code: z = torch.nn.functional.relu(y)
        z: "f32[4096, 4096][4096, 1]cuda:0" = torch.nn.functional.relu(y); y = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:149 in foo3, code: u = z * 2
        u: "f32[4096, 4096][4096, 1]cuda:0" = z * 2; z = None
        return (u,)
compile: 0.402176513671875
eager: 0.03534950256347656
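Notice that torch.compile appears to take much longer to complete than eager mode. This is because torch.compile spends extra time compiling the model during its first few executions. torch.compile reuses compiled code whenever possible, so if we run the optimized model several more times, we should see a significant improvement over eager mode.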
# turn off logging for now to prevent spam
torch._logging.set_logs(graph_code=False)
eager_times = []
for i in range(10):
    _, eager_time = timed(lambda: foo3(inp))
    eager_times.append(eager_time)
    print(f"eager time {i}: {eager_time}")
print("~" * 10)
compile_times = []
for i in range(10):
    _, compile_time = timed(lambda: opt_foo3(inp))
    compile_times.append(compile_time)
    print(f"compile time {i}: {compile_time}")
print("~" * 10)
import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
eager time 0: 0.0009215999841690063
eager time 1: 0.0008765439987182618
eager time 2: 0.0008765439987182618
eager time 3: 0.0008714240193367004
eager time 4: 0.0008704000115394592
eager time 5: 0.0008714240193367004
eager time 6: 0.0008755199909210205
eager time 7: 0.0008724480271339416
eager time 8: 0.0008734719753265381
eager time 9: 0.0008734719753265381
~~~~~~~~~~
compile time 0: 0.0005631999969482422
compile time 1: 0.0003911679983139038
compile time 2: 0.0003860479891300201
compile time 3: 0.00038707199692726135
compile time 4: 0.00038912001252174375
compile time 5: 0.00038598400354385376
compile time 6: 0.00038502401113510134
compile time 7: 0.0003829759955406189
compile time 8: 0.0003829759955406189
compile time 9: 0.00038912001252174375
~~~~~~~~~~
(eval) eager median: 0.0008734719753265381, compile median: 0.00038655999302864073, speedup: 2.2596026259288076x
~~~~~~~~~~
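Indeed, running the model with torch.compile gives a significant speedup. The speedup mainly comes from reducing Python overhead and GPU memory reads/writes (for example, by fusing several operations into a single kernel), so the observed speedup can vary with factors such as model architecture and batch size. For example, if a model's architecture is simple and the amount of data is large, the bottleneck will be GPU compute and the observed speedup may be less significant.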
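A back-of-the-envelope calculation shows why fusing pointwise operations helps foo3. The sketch below rests on two assumptions: each eager operation reads its input from GPU memory once and writes its output once, and the compiled version fuses all three operations into one kernel. It ignores caching and Python overhead, so the ratio is only a rough estimate, not a prediction.

```python
# Rough memory-traffic model for foo3 on a 4096 x 4096 float32 input.
# Assumption: each eager pointwise op does one full read and one full write,
# while the fused compiled kernel reads x once and writes u once.
numel = 4096 * 4096
bytes_per_elem = 4  # float32
tensor_bytes = numel * bytes_per_elem

eager_ops = 3  # x + 1, relu, * 2
eager_traffic = eager_ops * 2 * tensor_bytes  # one read + one write per op
fused_traffic = 2 * tensor_bytes  # single read of x, single write of u

MiB = 2**20
print(tensor_bytes / MiB, eager_traffic / MiB, fused_traffic / MiB)  # 64.0 384.0 128.0
print("traffic ratio:", eager_traffic / fused_traffic)  # 3.0
```

Under these assumptions the compiled version moves about a third of the bytes. That is in the same range as the roughly 2.26x speedup measured above.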
To see speedups on a real model, check out our end-to-end torch.compile tutorial.
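When you reproduce these measurements on other hardware, keep the first calls out of the timing, because they include compilation. Here is a minimal, device-agnostic sketch that uses time.perf_counter instead of CUDA events. `median_runtime` is a hypothetical helper, not part of PyTorch.

```python
import statistics
import time

import torch


def median_runtime(fn, warmup=3, iters=10):
    """Hypothetical helper: median wall-clock seconds of fn() after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # the first call to a compiled function includes compilation time
    times = []
    for _ in range(iters):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # keep previously queued GPU work out of the timing
        start = time.perf_counter()
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for this call's kernels to finish
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    return z * 2


x_cpu = torch.randn(1024, 1024)
print("eager median (s):", median_runtime(lambda: foo3(x_cpu)))
```

To compare against the compiled version, create opt_foo3 = torch.compile(foo3) once and time lambda: opt_foo3(x_cpu) the same way. On CPU, the default backend needs the C++ compiler listed in the system requirements.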
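Benefits over TorchScript
Why should we use torch.compile over TorchScript? Primarily, torch.compile can handle arbitrary Python code with minimal changes to existing code.
Compare this to TorchScript, which has a tracing mode (torch.jit.trace) and a scripting mode (torch.jit.script). Tracing mode is susceptible to silent incorrectness, while scripting mode requires significant code changes and raises errors on unsupported Python code.
For example, TorchScript tracing silently fails on data-dependent control flow (the if x.sum() < 0: line below) because only the control-flow path actually taken is traced. In comparison, torch.compile handles it correctly.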
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y
# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))
compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:239: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if x.sum() < 0:
traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~
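TorchScript scripting can handle data-dependent control flow, but it can require major code changes and raises errors when unsupported Python is used.
In the example below, we leave out TorchScript type annotations and get a TorchScript error: the argument y is an int, but TorchScript assumes that unannotated arguments are torch.Tensor. In comparison, torch.compile works without any type annotations.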
import traceback as tb
torch._logging.set_logs(graph_code=True)
def f2(x, y):
    return x + y
inp1 = torch.randn(5, 5)
inp2 = 3
script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 288, in <module>
script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
TRACED GRAPH
===== __compiled_fn_18_484f84c1_c0d9_4a00_885e_0e8bb60670cd =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[5, 5][5, 1]cpu"):
        l_x_ = L_x_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:280 in f2, code: return x + y
        add: "f32[5, 5][5, 1]cpu" = l_x_ + 3; l_x_ = None
        return (add,)
compile 2: True
~~~~~~~~~~
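Graph Breaks
The graph break is one of the most fundamental concepts in torch.compile. It lets torch.compile handle arbitrary Python code by interrupting compilation, running the unsupported code, and then resuming compilation. The term "graph break" comes from the fact that torch.compile tries to capture and optimize a graph of PyTorch operations. When it encounters unsupported Python code, that graph must be "broken." Graph breaks lose optimization opportunities, which may be undesirable, but that is far better than silent incorrectness or a crash.
Let's look at a data-dependent control flow example to better understand how graph breaks work.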
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
opt_bar = torch.compile(bar)
inp1 = torch.ones(10)
inp2 = torch.ones(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
TRACED GRAPH
===== __compiled_fn_20_0111169a_08f1_4ffa_90f7_4506ed967699 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
        sum_1: "f32[][]cpu" = l_b_.sum(); l_b_ = None
        lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
        return (lt, x)
TRACED GRAPH
===== __compiled_fn_24_20d9edfc_4340_4756_9d6c_5e9a3aa40661 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_x_ = L_x_
        l_b_ = L_b_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul: "f32[10][1]cpu" = l_x_ * l_b_; l_x_ = l_b_ = None
        return (mul,)
TRACED GRAPH
===== __compiled_fn_26_cb610501_517f_424b_a4f5_627b519e056c =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
        l_b_ = L_b_
        l_x_ = L_x_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
        b: "f32[10][1]cpu" = l_b_ * -1; l_b_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul_1: "f32[10][1]cpu" = l_x_ * b; l_x_ = b = None
        return (mul_1,)
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
0.5000])
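The first time we run bar, torch.compile traces 2 graphs corresponding to the following code (note that b.sum() < 0 is False):
1. x = a / (torch.abs(a) + 1); b.sum()
2. return x * b
The second time we run bar, we take the other branch of the if statement and get 1 new traced graph (graph 3), corresponding to the code b = b * -1; return x * b. The graph for x = a / (torch.abs(a) + 1); b.sum() is not printed a second time because torch.compile cached it during the first run and reused it.
Let's walk through how TorchDynamo steps through bar. If b.sum() < 0, TorchDynamo runs graph 1, lets Python determine the result of the conditional, and then runs graph 3. Otherwise, if not b.sum() < 0, TorchDynamo runs graph 1, lets Python determine the result of the conditional, and then runs graph 2.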
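You can check the graph count described above programmatically. This minimal sketch reuses bar with a hypothetical `counting_backend` that records every graph Dynamo passes to it. With the example's inputs, the first call should produce graphs 1 and 2, and the second call should add only graph 3.

```python
import torch

graphs = []  # hypothetical: every FX graph Dynamo passes to the backend


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    graphs.append(gm)
    return gm.forward


def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:  # data-dependent branch: Dynamo breaks the graph here
        b = b * -1
    return x * b


opt_bar = torch.compile(bar, backend=counting_backend)
inp1, inp2 = torch.ones(10), torch.ones(10)

opt_bar(inp1, inp2)  # b.sum() >= 0: graph before the branch plus the `return x * b` graph
after_first = len(graphs)
opt_bar(inp1, -inp2)  # other branch: the cached first graph is reused
after_second = len(graphs)
print(after_first, after_second)
torch.testing.assert_close(opt_bar(inp1, -inp2), bar(inp1, -inp2))
```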
We can see all graph breaks by using torch._logging.set_logs(graph_breaks=True).
TRACED GRAPH
===== __compiled_fn_28_6c4c5807_990c_4a05_906e_4a07cbd14123 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
        sum_1: "f32[][]cpu" = l_b_.sum(); l_b_ = None
        lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
        return (lt, x)
TRACED GRAPH
===== __compiled_fn_32_661112d1_89a7_4b38_8e60_3d29f201dd38 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_x_ = L_x_
        l_b_ = L_b_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul: "f32[10][1]cpu" = l_x_ * l_b_; l_x_ = l_b_ = None
        return (mul,)
TRACED GRAPH
===== __compiled_fn_34_cffb46aa_c480_4c80_80a1_0e698678b490 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
        l_b_ = L_b_
        l_x_ = L_x_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
        b: "f32[10][1]cpu" = l_b_ * -1; l_b_ = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul_1: "f32[10][1]cpu" = l_x_ * b; l_x_ = b = None
        return (mul_1,)
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
0.5000])
To maximize speedup, keep graph breaks to a minimum. We can force TorchDynamo to raise an error at the first graph break it encounters by passing fullgraph=True.
# Reset to clear the torch.compile cache
torch._dynamo.reset()
opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
try:
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 360, in <module>
opt_bar_fullgraph(torch.randn(10), torch.randn(10))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1058, in compile_wrapper
raise e.with_traceback(None) from e.__cause__ # User compiler error
torch._dynamo.exc.Unsupported: Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
The branch condition involves a tensor computed as follows:
# File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 313, in bar, code: if b.sum() < 0:
lt = lt(sum_1, 0)
Hint: The branch condition uses a scalar integer tensor. Consider rewriting the computation to use plain Python ints (e.g. use int attributes instead of tensor buffers) so the condition becomes a shape guard instead of data-dependent branching.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://meta-pytorch.cn/compile-graph-break-site/gb/gb0170.html
from user code:
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 313, in bar
if b.sum() < 0:
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
In the example above, we can work around this graph break by replacing the if statement with torch.cond:
from functorch.experimental.control_flow import cond
@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)
    def true_branch(y):
        return y * -1
    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()
    b = cond(b.sum() < 0, true_branch, false_branch, (b,))
    return x * b
bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)
TRACED GRAPH
===== __compiled_fn_37_e141e3ac_5cf7_4a73_861a_f26d281e34bc =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:373 in bar_fixed, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:382 in bar_fixed, code: b = cond(b.sum() < 0, true_branch, false_branch, (b,))
        sum_1: "f32[][]cpu" = l_b_.sum()
        lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
        cond_true_0 = self.cond_true_0
        cond_false_0 = self.cond_false_0
        cond = torch.ops.higher_order.cond(lt, cond_true_0, cond_false_0, (l_b_,)); lt = cond_true_0 = cond_false_0 = l_b_ = None
        b: "f32[10][1]cpu" = cond[0]; cond = None
        # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:383 in bar_fixed, code: return x * b
        mul: "f32[10][1]cpu" = x * b; x = b = None
        return (mul,)
    class cond_true_0(torch.nn.Module):
        def forward(self, l_b_: "f32[10][1]cpu"):
            l_b__1 = l_b_
            # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:376 in true_branch, code: return y * -1
            mul: "f32[10][1]cpu" = l_b__1 * -1; l_b__1 = None
            return (mul,)
    class cond_false_0(torch.nn.Module):
        def forward(self, l_b_: "f32[10][1]cpu"):
            l_b__1 = l_b_
            # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:380 in false_branch, code: return y.clone()
            clone: "f32[10][1]cpu" = l_b__1.clone(); l_b__1 = None
            return (clone,)
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
0.5000])
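torch.cond is the general tool for this pattern. For this particular branch there is also an alternative, not the approach this tutorial uses: compute both candidate values and select between them with torch.where. That leaves no Python-level control flow for Dynamo to break on. A minimal sketch that uses backend="eager", so only tracing is exercised:

```python
import torch


# Branchless rewrite: the scalar condition broadcasts over b, so Dynamo sees
# ordinary tensor ops instead of a Python `if` on a tensor value.
@torch.compile(fullgraph=True, backend="eager")
def bar_where(a, b):
    x = a / (torch.abs(a) + 1)
    b = torch.where(b.sum() < 0, -b, b)
    return x * b


def bar(a, b):  # eager reference with the original data-dependent branch
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


inp1, inp2 = torch.ones(10), torch.ones(10)
for b in (inp2, -inp2):
    torch.testing.assert_close(bar_where(inp1, b), bar(inp1, b))
```

The trade-off is that torch.where evaluates both -b and b regardless of the condition. That is cheap here but can be wasteful when the branches are expensive, whereas torch.cond is meant to execute only the selected branch.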
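To serialize graphs or run them in a different (that is, Python-less) environment, consider using torch.export (available from PyTorch 2.1). One important restriction is that torch.export does not support graph breaks. See the torch.export tutorial for more details.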
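See the graph breaks section of the torch.compile programming model for tips on working around graph breaks.
Troubleshooting
Is torch.compile failing to speed up your model? Is compile time unreasonably long? Is your code recompiling too often? Are you having trouble dealing with graph breaks? Are you looking for tips on how best to use torch.compile? Or do you just want to learn more about the inner workings of torch.compile?
Check out the torch.compile programming model.
Conclusion
In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing it to TorchScript, and briefly describing graph breaks.
For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial.
To troubleshoot issues and gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.
We hope that you will give torch.compile a try!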