
Introduction to torch.compile#

Created On: Mar 15, 2023 | Last Updated: Oct 15, 2025 | Last Verified: Nov 05, 2024

Author: William Wen

torch.compile is the latest way to speed up your PyTorch code! It makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes.

torch.compile accomplishes this by tracing through your Python code, looking for PyTorch operations. Code that is difficult to trace results in a **graph break**, where optimization opportunities are lost, rather than an error or silent incorrectness.

torch.compile is available in PyTorch 2.0 and later.
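As a quick, illustrative sanity check (not part of the original tutorial), you can verify that your installation ships torch.compile before proceeding:

import torch

# torch.compile was introduced in PyTorch 2.0; guard against older installs.
assert hasattr(torch, "compile"), "this tutorial requires PyTorch >= 2.0"
print(torch.__version__)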

This introduction covers basic torch.compile usage and demonstrates the advantages of torch.compile over our previous PyTorch compiler solution, TorchScript.

For an end-to-end example on a real model, check out our torch.compile end-to-end tutorial.

To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.


Required pip dependencies for this tutorial:

  • torch >= 2.0

  • numpy

  • scipy

System requirements:

  • A C++ compiler, such as g++

  • Python development package (python-devel/python-dev)

Basic Usage#

In this tutorial, we enable some logging to help us see what torch.compile is doing under the hood. The code below will print out the PyTorch operations that torch.compile traces.

import torch


torch._logging.set_logs(graph_code=True)

torch.compile takes an arbitrary Python function and returns an optimized version of it; it can be called directly on the function or used as a decorator.

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))


@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
 ===== __compiled_fn_1_57703c6c_17e9_44be_adf9_87ae8a7f015f =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:74 in foo, code: a = torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:75 in foo, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_);  l_y_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:76 in foo, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b;  a = b = None
        return (add,)


tensor([[ 0.0663,  1.8726,  1.0057],
        [-0.3487,  0.3188,  0.9310],
        [ 1.8560,  0.4513, -0.4614]])
TRACED GRAPH
 ===== __compiled_fn_3_12712180_e493_4bc2_8b8e_dcdfd783faaa =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:85 in opt_foo2, code: a = torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:86 in opt_foo2, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_);  l_y_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:87 in opt_foo2, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b;  a = b = None
        return (add,)


tensor([[ 0.2038,  0.5530,  0.2229],
        [-0.3382,  0.5160, -0.0161],
        [ 1.7310,  1.3559,  1.2261]])

torch.compile is applied recursively, so nested function calls within the top-level compiled function will also be compiled.

def inner(x):
    return torch.sin(x)


@torch.compile
def outer(x, y):
    a = inner(x)
    b = torch.cos(y)
    return a + b


print(outer(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
 ===== __compiled_fn_5_03c189a8_83d7_41cc_a42b_e8e8d534d682 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:98 in inner, code: return torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:104 in outer, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_);  l_y_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:105 in outer, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b;  a = b = None
        return (add,)


tensor([[ 1.2845, -0.0892, -0.2115],
        [ 1.3537, -0.0816, -0.0732],
        [-0.3591,  1.5748,  0.7948]])

We can also optimize torch.nn.Module instances, either by calling their .compile() method or by torch.compile-ing the module directly. This is equivalent to torch.compile-ing the module's __call__ method (which indirectly calls forward).

t = torch.randn(10, 100)


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))


mod1 = MyModule()
mod1.compile()
print(mod1(torch.randn(3, 3)))

mod2 = MyModule()
mod2 = torch.compile(mod2)
print(mod2(torch.randn(3, 3)))
TRACED GRAPH
 ===== __compiled_fn_7_d919aa2b_ce68_443d_ab75_c1f3ad8968a4 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_lin_parameters_weight_: "f32[3, 3][3, 1]cpu", L_self_modules_lin_parameters_bias_: "f32[3][1]cpu", L_x_: "f32[3, 3][3, 1]cpu"):
        l_self_modules_lin_parameters_weight_ = L_self_modules_lin_parameters_weight_
        l_self_modules_lin_parameters_bias_ = L_self_modules_lin_parameters_bias_
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:126 in forward, code: return torch.nn.functional.relu(self.lin(x))
        linear: "f32[3, 3][3, 1]cpu" = torch._C._nn.linear(l_x_, l_self_modules_lin_parameters_weight_, l_self_modules_lin_parameters_bias_);  l_x_ = l_self_modules_lin_parameters_weight_ = l_self_modules_lin_parameters_bias_ = None
        relu: "f32[3, 3][3, 1]cpu" = torch.nn.functional.relu(linear);  linear = None
        return (relu,)


tensor([[0.4863, 0.2575, 0.5411],
        [0.1428, 0.0000, 0.3762],
        [0.4444, 0.5583, 0.7902]], grad_fn=<CompiledFunctionBackward>)
tensor([[0.0000, 0.0000, 1.4330],
        [0.0000, 0.0000, 0.0536],
        [0.0000, 0.0000, 0.1456]], grad_fn=<CompiledFunctionBackward>)

Demonstrating Speedups#

Now, let's demonstrate how torch.compile can speed up a simple PyTorch example. For a demonstration on a more complex model, see our torch.compile end-to-end tutorial.

def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    u = z * 2
    return u


opt_foo3 = torch.compile(foo3)


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds; divide by 1000 to convert to seconds
    return result, start.elapsed_time(end) / 1000


inp = torch.randn(4096, 4096).cuda()
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])
TRACED GRAPH
 ===== __compiled_fn_9_08a72ca3_c6ee_45c6_a198_0e8c99e7092d =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[4096, 4096][4096, 1]cuda:0"):
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:147 in foo3, code: y = x + 1
        y: "f32[4096, 4096][4096, 1]cuda:0" = l_x_ + 1;  l_x_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:148 in foo3, code: z = torch.nn.functional.relu(y)
        z: "f32[4096, 4096][4096, 1]cuda:0" = torch.nn.functional.relu(y);  y = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:149 in foo3, code: u = z * 2
        u: "f32[4096, 4096][4096, 1]cuda:0" = z * 2;  z = None
        return (u,)


compile: 0.40412646532058716
eager: 0.02964000031352043

Notice that torch.compile appears to take a lot longer to complete compared to eager. This is because torch.compile spends extra time compiling the model on its first few executions. torch.compile reuses compiled code whenever possible, so if we run our optimized model several more times, we should see a significant improvement compared to eager.

# turn off logging for now to prevent spam
torch._logging.set_logs(graph_code=False)

eager_times = []
for i in range(10):
    _, eager_time = timed(lambda: foo3(inp))
    eager_times.append(eager_time)
    print(f"eager time {i}: {eager_time}")
print("~" * 10)

compile_times = []
for i in range(10):
    _, compile_time = timed(lambda: opt_foo3(inp))
    compile_times.append(compile_time)
    print(f"compile time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
eager time 0: 0.00088900001719594
eager time 1: 0.0008459999808110297
eager time 2: 0.0008459999808110297
eager time 3: 0.0008479999960400164
eager time 4: 0.000846999988425523
eager time 5: 0.0008420000085607171
eager time 6: 0.0008420000085607171
eager time 7: 0.0008509375038556755
eager time 8: 0.0008399999933317304
eager time 9: 0.0008440000237897038
~~~~~~~~~~
compile time 0: 0.0005019999807700515
compile time 1: 0.0003699999942909926
compile time 2: 0.00036100001307204366
compile time 3: 0.0003539999888744205
compile time 4: 0.00035700001171790063
compile time 5: 0.0003530000103637576
compile time 6: 0.0003530000103637576
compile time 7: 0.0003499999875202775
compile time 8: 0.0003539999888744205
compile time 9: 0.0003530000103637576
~~~~~~~~~~
(eval) eager median: 0.0008459999808110297, compile median: 0.0003539999888744205, speedup: 2.389830529376495x
~~~~~~~~~~

Indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup mainly comes from reducing Python overhead and GPU read/writes, so the observed speedup may vary with factors such as model architecture and batch size. For example, if a model's architecture is simple and the amount of data is large, then the bottleneck will be GPU compute and the observed speedup may be less significant.
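To make the compute-bound case concrete, here is an illustrative sketch (not part of the original tutorial) that reuses timed and inp from above. The function foo4 is a hypothetical example dominated by a single large matmul, leaving little Python overhead for torch.compile to remove, so we would expect the measured gap to be much smaller:

def foo4(x):
    # A single large GEMM dominates the runtime, so the GPU kernel,
    # not Python overhead, is the bottleneck.
    return x @ x


opt_foo4 = torch.compile(foo4)
opt_foo4(inp)  # warm up so compilation time is excluded from the measurement

print("compile (matmul):", timed(lambda: opt_foo4(inp))[1])
print("eager (matmul):", timed(lambda: foo4(inp))[1])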

To see speedups on real models, check out our torch.compile end-to-end tutorial.

Advantages over TorchScript#

Why should we use torch.compile over TorchScript? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

TorchScript, by comparison, has a tracing mode (torch.jit.trace) and a scripting mode (torch.jit.script). Tracing mode is susceptible to silent incorrectness, while scripting mode requires significant code changes and raises errors on unsupported Python code.

For example, TorchScript tracing silently fails on data-dependent control flow (the if x.sum() < 0: line below), since only the actual control flow path taken during tracing is captured. In contrast, torch.compile handles it correctly.

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y


# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)


inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:239: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~

TorchScript scripting can handle data-dependent control flow, but it can require significant code changes and raises errors when unsupported Python is used.

In the example below, we forget the type annotations that TorchScript needs, and we receive a TorchScript error because the input type for argument y, an int, does not match the default argument type, torch.Tensor. In contrast, torch.compile works without requiring any type annotations.

import traceback as tb

torch._logging.set_logs(graph_code=True)


def f2(x, y):
    return x + y


inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except:
    tb.print_exc()

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 288, in <module>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
TRACED GRAPH
 ===== __compiled_fn_18_60f88fab_6a3d_4dcc_a2ea_16a1899bfb1f =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[5, 5][5, 1]cpu"):
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:280 in f2, code: return x + y
        add: "f32[5, 5][5, 1]cpu" = l_x_ + 3;  l_x_ = None
        return (add,)


compile 2: True
~~~~~~~~~~

Graph Breaks#

Graph breaks are one of the most fundamental concepts in torch.compile. They enable torch.compile to handle arbitrary Python code by interrupting compilation, running the unsupported code, and then resuming compilation. The term "graph break" comes from the fact that torch.compile attempts to capture and optimize a graph of PyTorch operations; when unsupported Python code is encountered, that graph must be "broken". A graph break results in lost optimization opportunities, which may still be undesirable, but this is better than silent incorrectness or a hard crash.

Let's step through an example with data-dependent control flow to get a better sense of how graph breaks work.

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


opt_bar = torch.compile(bar)
inp1 = torch.ones(10)
inp2 = torch.ones(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
TRACED GRAPH
 ===== __compiled_fn_20_d5309909_d209_4382_9b82_0ba74ced4ca8 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1;  abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add;  l_a_ = add = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
        sum_1: "f32[][]cpu" = l_b_.sum();  l_b_ = None
        lt: "b8[][]cpu" = sum_1 < 0;  sum_1 = None
        return (lt, x)


TRACED GRAPH
 ===== __compiled_fn_24_24e667b5_a8e5_442d_b94a_a878f1114d23 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_x_ = L_x_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul: "f32[10][1]cpu" = l_x_ * l_b_;  l_x_ = l_b_ = None
        return (mul,)


TRACED GRAPH
 ===== __compiled_fn_26_d1830df0_39a5_4379_96f3_af6c112110cd =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
        l_b_ = L_b_
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
        b: "f32[10][1]cpu" = l_b_ * -1;  l_b_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul_1: "f32[10][1]cpu" = l_x_ * b;  l_x_ = b = None
        return (mul_1,)



tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

The first time we run bar, we see that torch.compile traced 2 graphs, corresponding to the following code (note that b.sum() < 0 is False):

  1. x = a / (torch.abs(a) + 1); b.sum()

  2. return x * b

The second time we run bar, we take the other branch of the if statement and get 1 traced graph, corresponding to the code b = b * -1; return x * b. We do not see graph output for x = a / (torch.abs(a) + 1) on the second run because torch.compile cached that graph from the first run and reused it.

Let's investigate by example how TorchDynamo steps through bar. If b.sum() < 0, then TorchDynamo runs graph 1, lets Python determine the result of the conditional, and then runs graph 3 (the graph containing b = b * -1). On the other hand, if not b.sum() < 0, TorchDynamo runs graph 1, lets Python determine the result of the conditional, and then runs graph 2.
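Conceptually, the stitched-together behavior can be sketched as plain Python (this is our illustration, not code produced by torch.compile; graph_1, graph_2, and graph_3 are hypothetical stand-ins for the compiled graphs shown in the logs above):

# Hypothetical stand-ins for the three compiled graphs above.
def graph_1(a, b):
    x = a / (torch.abs(a) + 1)
    return b.sum() < 0, x


def graph_2(x, b):
    return x * b


def graph_3(b, x):
    return x * (b * -1)


def bar_equivalent(a, b):
    lt, x = graph_1(a, b)
    if lt:  # plain Python evaluates the data-dependent conditional
        return graph_3(b, x)
    return graph_2(x, b)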

We can see all graph breaks by using torch._logging.set_logs(graph_breaks=True).

# Reset to clear the torch.compile cache
torch._dynamo.reset()
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
TRACED GRAPH
 ===== __compiled_fn_28_e75c1c8c_4795_4a16_8d6f_90d489a9e78e =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1;  abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add;  l_a_ = add = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
        sum_1: "f32[][]cpu" = l_b_.sum();  l_b_ = None
        lt: "b8[][]cpu" = sum_1 < 0;  sum_1 = None
        return (lt, x)


TRACED GRAPH
 ===== __compiled_fn_32_e26b0760_f8cc_414d_a852_6092ac007ca7 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_x_ = L_x_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul: "f32[10][1]cpu" = l_x_ * l_b_;  l_x_ = l_b_ = None
        return (mul,)


TRACED GRAPH
 ===== __compiled_fn_34_2b406644_b833_40a0_96ec_c1f387d13c7f =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
        l_b_ = L_b_
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
        b: "f32[10][1]cpu" = l_b_ * -1;  l_b_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul_1: "f32[10][1]cpu" = l_x_ * b;  l_x_ = b = None
        return (mul_1,)



tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

In order to maximize speedups, graph breaks should be limited. We can force TorchDynamo to raise an error upon the first graph break it encounters by using fullgraph=True.

# Reset to clear the torch.compile cache
torch._dynamo.reset()

opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
try:
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
except:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 360, in <module>
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 841, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
torch._dynamo.exc.Unsupported: Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html

from user code:
   File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 313, in bar
    if b.sum() < 0:

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

In our example above, we can work around this graph break by replacing the if statement with torch.cond.

from functorch.experimental.control_flow import cond


@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)

    def true_branch(y):
        return y * -1

    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()

    x = cond(b.sum() < 0, true_branch, false_branch, (b,))
    return x * b


bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)
TRACED GRAPH
 ===== __compiled_fn_37_6c5f108a_d951_495b_a538_024359c8fc5a =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:373 in bar_fixed, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1;  abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add;  l_a_ = add = x = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:382 in bar_fixed, code: x = cond(b.sum() < 0, true_branch, false_branch, (b,))
        sum_1: "f32[][]cpu" = l_b_.sum()
        lt: "b8[][]cpu" = sum_1 < 0;  sum_1 = None

         # File: /usr/local/lib/python3.10/dist-packages/torch/_higher_order_ops/cond.py:186 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
        cond_true_0 = self.cond_true_0
        cond_false_0 = self.cond_false_0
        cond = torch.ops.higher_order.cond(lt, cond_true_0, cond_false_0, (l_b_,));  lt = cond_true_0 = cond_false_0 = None
        x_1: "f32[10][1]cpu" = cond[0];  cond = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:383 in bar_fixed, code: return x * b
        mul: "f32[10][1]cpu" = x_1 * l_b_;  x_1 = l_b_ = None
        return (mul,)

    class cond_true_0(torch.nn.Module):
        def forward(self, l_b_: "f32[10][1]cpu"):
            l_b__1 = l_b_

             # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:376 in true_branch, code: return y * -1
            mul: "f32[10][1]cpu" = l_b__1 * -1;  l_b__1 = None
            return (mul,)

    class cond_false_0(torch.nn.Module):
        def forward(self, l_b_: "f32[10][1]cpu"):
            l_b__1 = l_b_

             # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:380 in false_branch, code: return y.clone()
            clone: "f32[10][1]cpu" = l_b__1.clone();  l_b__1 = None
            return (clone,)



tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.])

In order to serialize graphs, or to run graphs in different (e.g. Python-less) environments, consider using torch.export instead (available from PyTorch 2.1+). One important restriction is that torch.export does not support graph breaks. Please check out the torch.export tutorial for more details on torch.export.
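As a minimal, hedged sketch of that workflow (assuming the PyTorch 2.1+ torch.export API; the module M and the path m.pt2 are illustrative, not from the tutorial):

import torch


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2


# torch.export captures one whole graph ahead of time; unlike torch.compile,
# it cannot fall back to running Python, so a graph break is a hard error.
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)  # an ExportedProgram containing the captured graph

# The exported program can be serialized and loaded elsewhere.
torch.export.save(ep, "m.pt2")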

Check out the section on graph breaks in the torch.compile programming model for tips on how to work around graph breaks.

Troubleshooting#

Is torch.compile failing to speed up your model? Are compile times too long? Is your code recompiling excessively? Are you having a hard time dealing with graph breaks? Are you looking for tips on how to best use torch.compile? Or do you simply want to learn more about torch.compile's internals?

Check out the torch.compile programming model.

Conclusion#

In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing to TorchScript, and briefly describing graph breaks.

For an end-to-end example on a real model, check out our torch.compile end-to-end tutorial.

To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.

We hope that you will give torch.compile a try!

Total running time of the script: (0 minutes 16.527 seconds)