Dealing with Recompilations#
Created On: Jul 29, 2025 | Last Updated: Jul 29, 2025
Recompilations are necessary to guarantee the soundness of torch.compile, but they can significantly increase compile time. Minimizing recompilations while preserving soundness is therefore essential for keeping compile time down.
You can view recompilations and their reasons using tlparse or TORCH_LOGS=recompiles.
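As a quick reference, here is one way to turn these on; a minimal sketch, assuming the standard TORCH_LOGS/TORCH_TRACE environment variables and the torch._logging.set_logs API:
# Option 1: environment variables, set before launching Python.
#   TORCH_LOGS="recompiles" python train.py     # log recompile reasons
#   TORCH_TRACE=/tmp/trace  python train.py     # then inspect with: tlparse /tmp/trace
# Option 2: enable recompile logging from inside the process.
import torch._logging
torch._logging.set_logs(recompiles=True)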
是否启用了动态形状?#
In the example below, we recompile due to mismatched shapes:
@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3))
fn(torch.ones(4))
Recompiling function fn in /tmp/ipykernel_990/2479206322.py:1
triggered by the following guard failure(s):
- 0/0: tensor 'x' size mismatch at index 0. expected 3, actual 4
tensor([2., 2., 2., 2.])
Make sure that the dynamic option of torch.compile is not set to False. The default, dynamic=None, only attempts dynamic shapes after the first compilation. You can set dynamic=True to compile with dynamic shapes upfront wherever possible:
@torch.compile(dynamic=True)
def gn(x):
    return x + 1

gn(torch.ones(3))
gn(torch.ones(4))
tensor([2., 2., 2., 2.])
For more information on dynamic shapes, including how to deal with errors and recompilations caused by dynamic shapes, see the Dynamic Shapes manual.
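If you already know which dimension will vary, you can also mark it dynamic on the input tensor before the first call; a minimal sketch using torch._dynamo.mark_dynamic (behavior may differ across PyTorch versions):
# Mark dim 0 of x as dynamic so the very first compilation uses a
# symbolic size instead of specializing on 3.
x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)

@torch.compile
def hn(x):
    return x + 1

hn(x)               # compiled once with a dynamic dim 0
hn(torch.ones(4))   # should reuse the dynamic compilation, no recompile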
Wrapping constants with tensors#
By default, int / float variables are treated as constants and are guarded on their exact values. In the example below, every function call triggers a recompilation:
@torch.compile
def fn(x, c):
    return x + c

for i in range(5):
    fn(torch.ones(i), 0.5 + i)
Recompiling function fn in /tmp/ipykernel_990/3647755280.py:1
triggered by the following guard failure(s):
- 2/0: c == 0.5 # return x + c # mp/ipykernel_990/3647755280.py:3 in fn
Recompiling function fn in /tmp/ipykernel_990/3647755280.py:1
triggered by the following guard failure(s):
- 2/1: tensor 'x' size mismatch at index 0. expected 1, actual 2
- 2/0: c == 0.5 # return x + c # mp/ipykernel_990/3647755280.py:3 in fn
In particular, initializing an LR scheduler with a constant can lead to recompilations:
mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def gn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(5):
    gn(torch.ones(3, 3))
Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
Recompiling function step in /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adam.py:213
triggered by the following guard failure(s):
- 7/0: self.param_groups[0]['lr'] == 0.01 # for group in self.param_groups: # optim/adam.py:228 in step
In both of the examples above, we can wrap the float variables in tensors to prevent recompilations:
# first example
for i in range(5):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
for i in range(5):
    gn(torch.ones(3, 3))
Recompiling function fn in /tmp/ipykernel_990/3647755280.py:1
triggered by the following guard failure(s):
- 0/0: tensor 'x' size mismatch at index 0. expected 0, actual 1
Recompiling function fn in /tmp/ipykernel_990/3647755280.py:1
triggered by the following guard failure(s):
- 0/1: tensor 'x' size mismatch at index 0. expected 1, actual 2
- 0/0: tensor 'x' size mismatch at index 0. expected 0, actual 2
Changing the cache size limit#
There is a limit to how many times a function can be recompiled, determined by torch._dynamo.config.cache_size_limit and torch._dynamo.config.accumulated_cache_size_limit (the exact difference between these two values is detailed in torch/_dynamo/cache_size.py). If the Dynamo cache limit is hit, all future compilation attempts **result in the function being skipped (run eagerly)**. Dynamo will still attempt to use previously compiled bytecode for subsequent calls if the guards pass. Note that when the recompile limit is hit, **all nested function calls are skipped** as well (Dynamo will try to use previously compiled bytecode for the nested functions). Dynamo also emits a warning naming the affected function and which limit was hit. In the example below, each function call triggers a recompile attempt. When we hit the cache size limit (8 by default), we stop attempting to recompile. (Note that we set dynamic=False for demonstration purposes, to force a recompilation every time.)
@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    # recompile every time due to dynamic=False
    fn(torch.ones(i))
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 2
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 3
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 3
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 4
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 4
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 4
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 5
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 5
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 5
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 5
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 6
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 6
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 6
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 6
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 6
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 7
- 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 7
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 7
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 7
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 7
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 7
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/6: tensor 'x' size mismatch at index 0. expected 7, actual 8
- 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 8
- 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 8
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 8
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 8
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 8
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 8
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/7: tensor 'x' size mismatch at index 0. expected 8, actual 9
- 8/6: tensor 'x' size mismatch at index 0. expected 7, actual 9
- 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 9
- 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 9
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 9
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 9
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 9
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 9
torch._dynamo hit config.recompile_limit (8)
function: 'fn' (/tmp/ipykernel_990/3054308037.py:1)
last reason: 8/7: tensor 'x' size mismatch at index 0. expected 8, actual 9
To log all recompilation reasons, use TORCH_LOGS="recompiles".
To diagnose recompilation issues, see https://pytorch.ac.cn/docs/stable/torch.compiler_troubleshooting.html
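As a quick sanity check, you can inspect both limits directly; a minimal sketch (attribute names as of recent PyTorch releases; see torch/_dynamo/cache_size.py for the exact semantics in your version):
import torch._dynamo

# Per-function (per code object) recompile limit; 8 by default.
print(torch._dynamo.config.cache_size_limit)
# Global limit accumulated across all compiled functions.
print(torch._dynamo.config.accumulated_cache_size_limit)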
If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit. If the cost of recompiling outweighs the benefit of compilation, you can instead consider lowering it.
torch._dynamo.config.cache_size_limit = 16

@torch.compile(dynamic=False)
def gn(x):
    return x + 1

for i in range(1, 10):
    gn(torch.ones(i))
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 2
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 3
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 3
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 4
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 4
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 4
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 5
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 5
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 5
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 5
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 6
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 6
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 6
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 6
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 6
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 7
- 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 7
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 7
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 7
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 7
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 7
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/6: tensor 'x' size mismatch at index 0. expected 7, actual 8
- 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 8
- 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 8
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 8
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 8
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 8
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 8
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/7: tensor 'x' size mismatch at index 0. expected 8, actual 9
- 9/6: tensor 'x' size mismatch at index 0. expected 7, actual 9
- 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 9
- 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 9
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 9
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 9
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 9
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 9
Graph breaking to reduce recompilation costs#
If a large graph is being recompiled and driving up compile time, you can intentionally introduce a graph break to reduce recompilation costs, at the expense of a performance hit:
def very_large_function(x):
    return x + 1

@torch.compile(dynamic=False)
def fn(x, c):
    y = very_large_function(x)  # recompiled every time
    return y + c

for i in range(1, 5):
    fn(torch.ones(3), i)

@torch.compile(dynamic=False)
def gn(x, c):
    y = very_large_function(x)  # compiled only once
    torch._dynamo.graph_break()
    return y + c  # recompiled every time

for i in range(1, 5):
    gn(torch.ones(3), i)
Recompiling function fn in /tmp/ipykernel_990/2876112129.py:4
triggered by the following guard failure(s):
- 10/0: c == 1 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
Recompiling function fn in /tmp/ipykernel_990/2876112129.py:4
triggered by the following guard failure(s):
- 10/1: c == 2 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
- 10/0: c == 1 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
Recompiling function fn in /tmp/ipykernel_990/2876112129.py:4
triggered by the following guard failure(s):
- 10/2: c == 3 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
- 10/1: c == 2 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
- 10/0: c == 1 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_990/2876112129.py:15
triggered by the following guard failure(s):
- 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_990/2876112129.py:15
triggered by the following guard failure(s):
- 12/1: c == 2 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
- 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_990/2876112129.py:15
triggered by the following guard failure(s):
- 12/2: c == 3 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
- 12/1: c == 2 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
- 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15