torch.compiler.allow_in_graph#
- torch.compiler.allow_in_graph(fn)[源代码]#
指示编译器前端(Dynamo)在遇到函数时跳过其符号内省,而是直接将其写入图。.
如果您正在使用
torch.compile()
(使用 backend=”inductor”(默认值)),或者torch.export.export()
,并尝试在所有跟踪过程中对 Python 函数进行黑盒处理,请不要使用此 API。相反,请创建一个自定义运算符(请参阅 PyTorch 自定义运算符登陆页面)。警告
如果您是典型的 torch.compile 用户(例如,您正在将 torch.compile 应用于模型以使其运行得更快),您可能不想使用此函数。
allow_in_graph()
是一个“易错点”,因为它跳过了负责进行安全检查(图中断、处理闭包等)的编译器前端(Dynamo)。不正确的使用会导致难以调试的静默错误。对于没有 allow_in_graph 装饰器的 Python 函数,torch.compile 会正常跟踪该函数。
allow_in_graph()
会改变这一点,使得前端不会跟踪函数内部,但编译器后端仍然会跟踪它。与自定义运算符将函数视为整个 torch.compile 堆栈中的黑盒不同,以下表格比较了这些机制。机制
前端(Dynamo)
后端(AOTAutograd+Inductor)
无装饰器
跟踪内部
跟踪内部
allow_in_graph
不透明可调用
跟踪内部
自定义运算符
不透明可调用
不透明可调用
allow_in_graph() 的一个常见用例是作为编译器前端的逃生舱:如果您知道函数相对于编译堆栈(AOTAutograd 和 Inductor)的下游组件而言是有效的,但存在阻止其正确符号内省函数的 Dynamo 错误(或者您的代码是用 C/C++ 编写的,因此无法用 Dynamo 进行内省),那么可以为该函数使用 allow_in_graph() 装饰器来绕过 Dynamo。
我们要求
fn
遵循以下限制。未能遵守将导致未定义行为。fn 的输入必须是 FX 图中的 Proxy-able 类型。有效类型包括:Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?] Tuple[Tensor?, …]/Tuple[int?, …]/Tuple[float?, …]/torch.dtype/torch.device
fn 的输出必须是 FX 图中的 Proxy-able 类型(参见上一条)。
fn 中使用的所有 Tensor 都必须直接作为 fn 的输入传入(而不是作为捕获的变量)。
- 参数
fn – 一个可调用对象,代表要包含在图中的函数。如果
fn
是可调用对象的列表或元组,则它会递归地将allow_in_graph()
应用于每个函数,并返回一个包含修改后函数的新列表或元组。
示例
torch.compiler.allow_in_graph(my_custom_function) @torch.compile(...) def fn(x): x = torch.add(x, 1) x = my_custom_function(x) x = torch.add(x, 1) return x fn(...)
将捕获一个包含
my_custom_function()
的单个图。