Dynamic Compilation Control with torch.compiler.set_stance#
Author: William Wen
torch.compiler.set_stance is a torch.compiler API that allows you to change the behavior of torch.compile across different calls to your model without having to reapply torch.compile.
This recipe provides some examples of how to use torch.compiler.set_stance.
Prerequisites#
torch >= 2.6
Description#
torch.compiler.set_stance can be used as a decorator, context manager, or raw function to change the behavior of torch.compile across different calls to your model.
In the example below, the "force_eager" stance ignores all torch.compile directives.
import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1
inp = torch.zeros(3)
print(foo(inp)) # compiled, prints 1
tensor([1., 1., 1.])
Example decorator usage
@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)
print(bar(inp)) # not compiled, prints -1
tensor([-1., -1., -1.])
Example context manager usage
with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])
Example raw function usage
torch.compiler.set_stance("force_eager")
print(foo(inp)) # not compiled, prints -1
torch.compiler.set_stance("default")
print(foo(inp)) # compiled, prints 1
tensor([-1., -1., -1.])
tensor([1., 1., 1.])
The torch.compile stance can only be changed from outside of any torch.compile region. Attempting to do otherwise will raise an error.
@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)
@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)
Attempt to trace forbidden callable <function set_stance at 0x7f66ad509750>
from user code:
File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 85, in baz
with torch.compiler.set_stance("force_eager"):
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Attempt to trace forbidden callable <function inner at 0x7f66c86d3760>
from user code:
File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 103, in outer
return inner(x)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Other stances include:

- "default": The default stance, used for normal compilation.
- "eager_on_recompile": Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used.
- "fail_on_recompile": Raise an error when recompiling a function.
See the torch.compiler.set_stance doc page for more stances and options. More stances/options may also be added in the future.
Examples#
Preventing recompilation#
Some models do not expect any recompilations - for example, you may have inputs that are always the same shape. Since recompilations can be expensive, we may wish to raise an error whenever a recompile is attempted, so that we can detect and fix such cases. The "fail_on_recompile" stance serves this purpose.
@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py', function name: 'my_big_model', line number: 0
If raising an error is too disruptive, we can use "eager_on_recompile" instead, which causes torch.compile to fall back to eager mode rather than erroring out. This can be useful when we don't expect recompilations to happen frequently, but when one is required, we would rather pay the cost of running eagerly than the cost of recompiling.
@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([1., 1., 1.])
Measuring performance gains#
torch.compiler.set_stance can be used to compare eager vs. compiled performance without having to define a separate eager model.
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
eager: 0.00026208001375198364
compiled: 0.00012691199779510498
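The `timed` helper above requires a CUDA device. On a CPU-only machine, a wall-clock timer works instead; a minimal sketch (the name `timed_cpu` is ours, not part of the tutorial):

```python
import time


# CPU-only variant of the `timed` helper: time.perf_counter is a monotonic,
# high-resolution wall clock. On GPU, prefer CUDA events as shown above,
# since CUDA kernels launch asynchronously and a wall clock would only
# measure the launch, not the kernel execution.
def timed_cpu(fn):
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start
```

It can be used the same way, e.g. `timed_cpu(lambda: my_gigantic_model(*inps))[1]`.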
Crashing sooner#
Running an eager iteration first using the "force_eager" stance, before a compiled iteration, can help us catch errors unrelated to torch.compile before attempting a very long compile.
@torch.compile
def my_humongous_model(x):
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)
sin() takes 1 positional argument but 2 were given
Conclusion#
In this recipe, we learned how to use the torch.compiler.set_stance API to modify the behavior of torch.compile across different calls to a model without needing to reapply it. The recipe demonstrated using torch.compiler.set_stance as a decorator, context manager, or raw function to control compilation stances such as force_eager, default, eager_on_recompile, and fail_on_recompile.
For more information, see: torch.compiler.set_stance API documentation.
Total running time of the script: (0 minutes 10.660 seconds)