torch.export 编程模型#
创建时间:2024 年 12 月 18 日 | 最后更新时间:2025 年 7 月 16 日
本文档旨在解释 torch.export.export() 的行为和能力。它旨在帮助您直观地理解 torch.export.export() 如何处理代码。
追踪基础知识#
torch.export.export() 通过在“示例”输入上追踪其执行并记录追踪路径上观察到的 PyTorch 操作和条件来捕获表示您模型的图。然后,只要输入满足相同的条件,就可以使用不同的输入运行此图。
torch.export.export() 的基本输出是单个 PyTorch 操作图,以及相关的元数据。该输出的确切格式在 export IR 规范 中介绍。
严格追踪与非严格追踪#
torch.export.export() 提供两种追踪模式。
在非严格模式下,我们使用标准的 Python 解释器进行追踪。您的代码会像在 eager 模式下一样执行;唯一的区别是所有 Tensor 都被替换为 Fake Tensor,这些 Fake Tensor 具有形状和其他元数据,但没有数据,并被封装在 Proxy 对象 中,这些 Proxy 对象会将所有对它们的运算记录到一个图中。我们还捕获 Tensor 形状的条件,这些条件会保护生成代码的正确性。
在严格模式下,我们首先使用 TorchDynamo(一个 Python 字节码分析引擎)进行追踪。TorchDynamo 实际上不执行您的 Python 代码。相反,它会对其进行符号分析并根据结果构建图。一方面,这种分析允许 torch.export.export() 提供对 Python 级别安全性的额外保证(除了在非严格模式下捕获 Tensor 形状的条件之外)。另一方面,并非所有 Python 特性都支持这种分析。
尽管目前默认的追踪模式是严格模式,但我们强烈建议使用非严格模式,该模式很快将成为默认模式。对于大多数模型而言,Tensor 形状的条件足以保证正确性,并且对 Python 级别安全性的额外保证没有影响;同时,在 TorchDynamo 中遇到不支持的 Python 特性的可能性会带来不必要的风险。
在本档的其余部分,我们假设我们正在以 非严格模式 进行追踪;特别是,我们假设所有 Python 特性都得到支持。
值:静态与动态#
理解 torch.export.export() 行为的一个关键概念是静态值与动态值之间的区别。
静态值#
静态值是在导出时固定的,在导出程序执行之间不会改变的值。当在追踪过程中遇到该值时,我们将其视为常量并将其硬编码到图中。
当执行一个操作(例如 x + y)且所有输入都是静态的时,该操作的输出将被直接硬编码到图中,并且该操作不会显示(即它会被“常量折叠”)。
当一个值被硬编码到图中时,我们称该图已被专门化到该值。例如
import torch
class MyMod(torch.nn.Module):
def forward(self, x, y):
z = y + 7
return x + z
m = torch.export.export(MyMod(), (torch.randn(1), 3))
print(m.graph_module.code)
"""
def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(arg0_1, 10); arg0_1 = None
return (add,)
"""
在这里,我们将 3 作为 y 的追踪值;它被视为一个静态值,并被加到 7 中,从而在图中固化了静态值 10。
动态值#
动态值是可以从一次运行到另一次运行而改变的值。它的行为就像一个“正常”的函数参数:您可以传递不同的输入并期望您的函数执行正确。
哪些值是静态的,哪些是动态的?#
一个值是静态还是动态取决于它的类型
对于 Tensor
Tensor 的数据被视为动态。
Tensor 的形状可以被系统视为静态或动态。
默认情况下,所有输入 Tensor 的形状都被视为静态。用户可以通过为任何输入 Tensor 指定 动态形状 来覆盖此行为。
作为模块状态一部分的 Tensor,即参数和缓冲区,始终具有静态形状。
Tensor 的其他形式的元数据(例如
device、dtype)是静态的。
Python原始类型(
int、float、bool、str、None)是静态的。对于某些原始类型有动态变体(
SymInt、SymFloat、SymBool)。通常用户不必处理它们。用户可以通过为其指定 动态形状 来将整数输入指定为动态。
对于 Python标准容器(
list、tuple、dict、namedtuple)结构(即
list和tuple的长度,dict和namedtuple的键序列)是静态的。包含的元素递归地应用这些规则(基本上是 PyTree 方案),其中叶子是 Tensor 或原始类型。
其他类(包括数据类)可以注册到 PyTree(见下文),并遵循与标准容器相同的规则。
输入类型#
输入将根据其类型(如上所述)被视为静态或动态。
静态输入将被硬编码到图中,在运行时传递不同的值将导致错误。请记住,这些主要是原始类型的值。
动态输入就像“正常”函数输入一样。请记住,这些主要是 Tensor 类型的值。
默认情况下,您可以使用以下类型的输入来运行程序:
张量
Python 原始类型(
int、float、bool、str、None)Python 标准容器(
list、tuple、dict、namedtuple)
自定义输入类型(PyTree)#
此外,您还可以定义自己的(自定义)类并将其用作输入类型,但您需要将此类注册为 PyTree。
以下是一个使用实用程序注册用作输入类型的数据类的示例。
@dataclass
class Input:
f: torch.Tensor
p: torch.Tensor
import torch.utils._pytree as pytree
pytree.register_dataclass(Input)
class M(torch.nn.Module):
def forward(self, x: Input):
return x.f + 1
torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))
可选输入类型#
对于程序的可选输入但未传入的参数,torch.export.export() 将专门化其默认值。因此,导出的程序将要求用户显式传入所有参数,并且会丢失默认行为。例如
class M(torch.nn.Module):
def forward(self, x, y=None):
if y is not None:
return y * x
return x + x
# Optional input is passed in
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
# File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x); y = x = None
return (mul,)
"""
# Optional input is not passed in
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y):
# File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x); x = None
return (add,)
"""
控制流:静态与动态#
torch.export.export() 支持控制流。控制流的行为取决于您分支的值是静态还是动态。
静态控制流#
基于静态值的 Python 控制流得到透明支持。(请记住,静态值包括静态形状,因此基于静态形状的控制流也包含在此情况中。)
如上所述,我们“固化”静态值,因此导出的图将永远不会看到任何基于静态值的控制流。
在 if 语句的情况下,我们将继续追踪导出时选择的分支。在 for 或 while 语句的情况下,我们将通过展开循环继续追踪。
动态控制流:依赖形状 vs. 依赖数据#
当控制流中涉及的值是动态的时,它可能依赖于动态形状或动态数据。考虑到编译器使用形状信息而不是数据进行追踪,这些情况对编程模型的影响是不同的。
依赖动态形状的控制流#
当控制流中涉及的值是 动态形状 时,在大多数情况下我们也能在追踪期间得知动态形状的具体值:关于编译器如何追踪此信息,请参阅以下部分。
在这些情况下,我们称控制流为依赖形状的。我们使用动态形状的具体值来评估条件,以使其为 True 或 False,然后继续追踪(如上所述),并额外发出一个对应于刚刚评估的条件的 guard。
否则,控制流被视为依赖数据的。我们无法将条件评估为 True 或 False,因此无法继续追踪,必须在导出时引发错误。请参阅下一节。
依赖数据的控制流#
支持对动态值进行依赖数据的控制流,但您必须使用 PyTorch 的显式算子之一才能继续追踪。不允许在动态值上使用 Python 控制流语句,因为编译器无法评估继续追踪所需的条件,因此必须在导出时引发错误。
我们提供用于表示动态值上的通用条件和循环的算子,例如 torch.cond、torch.map。请注意,只有当您确实想要依赖数据的控制流时,才需要使用它们。
以下是一个依赖数据条件(x.sum() > 0,其中 x 是输入 Tensor)的 if 语句,使用 torch.cond 重写。它不必决定追踪哪个分支,现在两个分支都被追踪了。
class M_old(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
return torch.cond(
pred=x.sum() > 0,
true_fn=lambda x: x.sin(),
false_fn=lambda x: x.cos(),
operands=(x,),
)
数据相关控制流的一个特例是涉及 依赖数据的动态形状:通常,某些中间 Tensor 的形状依赖于输入数据而不是输入形状(因此不依赖形状)。在这种情况下,您可以使用断言来决定条件是 True 还是 False,而不是使用控制流算子。给定这样的断言,我们可以继续追踪,并像上面一样发出 guard。
我们提供用于表示动态形状断言的算子,例如 torch._check。请注意,只有当存在依赖数据的动态形状上的控制流时,才需要使用它。
以下是一个涉及依赖数据的动态形状(nz.shape[0] > 0,其中 nz 是调用 torch.nonzero() 的结果,一个输出形状依赖于输入数据的算子)的 if 语句。您可以通过使用 torch._check 添加断言来有效地决定追踪哪个分支,而无需重写。
class M_old(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
torch._check(nz.shape[0] > 0)
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
符号形状基础知识#
在追踪过程中,动态 Tensor 形状及其上的条件被编码为“符号表达式”。(相比之下,静态 Tensor 形状及其上的条件就是 int 和 bool 值。)
符号类似于变量;它描述了一个动态 Tensor 形状。
随着追踪的进行,中间 Tensor 的形状可能由更通用的表达式描述,通常涉及整数算术运算符。这是因为对于大多数 PyTorch 算子,输出 Tensor 的形状可以描述为输入 Tensor 形状的函数。例如,torch.cat() 的输出形状是其输入形状的总和。
此外,当我们遇到程序中的控制流时,我们会创建布尔表达式,通常涉及关系运算符,以描述追踪路径上的条件。这些表达式被评估以决定要追踪的程序路径,并记录在 形状环境中,以保护追踪路径的正确性并评估后续创建的表达式。
我们将在下面简要介绍这些子系统。
PyTorch 算子的 Fake 实现#
请注意,在追踪过程中,我们使用 Fake Tensor 执行程序,这些 Fake Tensor 没有数据。通常我们无法使用 Fake Tensor 调用实际的 PyTorch 算子实现。因此,每个算子都需要有一个额外的 Fake(也称为“meta”)实现,该实现接受和输出 Fake Tensor,并在形状和其他由 Fake Tensor 携带的元数据方面与实际实现的行为匹配。
例如,注意 torch.index_select() 的 Fake 实现如何使用输入形状来计算输出形状(同时忽略输入数据并返回空输出数据)。
def meta_index_select(self, dim, index):
result_size = list(self.size())
if self.dim() > 0:
result_size[dim] = index.numel()
return self.new_empty(result_size)
形状传播:Backed 与 Unbacked 动态形状#
形状使用 PyTorch 算子的 Fake 实现进行传播。
理解动态形状传播的一个关键概念是backed(已支持)和unbacked(未支持)动态形状之间的区别:前者我们知道具体值,后者则不知道。
形状的传播,包括追踪 backed 和 unbacked 动态形状,过程如下:
代表输入的 Tensor 的形状可以是静态的,也可以是动态的。当是动态的时,它们由符号描述;此外,这些符号是 backed 的,因为在导出时我们还知道用户提供的“真实”示例输入的具体值。
算子的输出形状由其 Fake 实现计算,可以是静态的,也可以是动态的。当是动态的时,通常由一个符号表达式描述。此外:
如果输出形状仅取决于输入形状,则当所有输入形状都是静态的或 backed 动态的时,它就是静态的或 backed 动态的。
另一方面,如果输出形状依赖于输入数据,它必然是动态的,而且,因为我们不知道其具体值,所以它是 unbacked 的。
控制流:Guards 和 Assertions#
当遇到形状上的条件时,它要么只涉及静态形状,在这种情况下它是 bool,要么涉及动态形状,在这种情况下它是符号布尔表达式。对于后者:
当条件仅涉及 backed 动态形状时,我们可以使用这些动态形状的具体值将条件评估为
True或False。然后,我们可以向形状环境添加一个 guard,声明相应的符号布尔表达式为True或False,并继续追踪。否则,条件涉及 unbacked 动态形状。通常,我们无法在没有额外信息的情况下评估这种条件;因此,我们无法继续追踪,并且必须在导出时引发错误。用户应使用显式的 PyTorch 算子进行追踪以继续。此信息将作为 guard 添加到形状环境中,并且还可能有助于将其他后续遇到的条件评估为
True或False。
模型导出后,任何关于 backed 动态形状的 guard 都可以被理解为对输入动态形状的条件。这些条件会与必须提供给 export 的动态形状规范进行验证,该规范描述了不仅示例输入,而且所有未来输入都应满足的动态形状条件,以确保生成代码的正确性。更准确地说,动态形状规范必须逻辑上包含生成的 guard,否则将在导出时引发错误(并提供对动态形状规范的建议修复)。另一方面,当没有关于 backed 动态形状的 guard 生成时(特别是在所有形状都是静态的时),则不需要向 export 提供动态形状规范。通常,动态形状规范会被转换为生成代码的运行时断言。
最后,关于 unbacked 动态形状的任何 guard 都将被转换为“内联”运行时断言。这些断言被添加到生成代码中,在创建那些 unbacked 动态形状的位置:通常是在数据相关算子调用之后。
允许的 PyTorch 算子#
允许使用所有 PyTorch 算子。
自定义算子#
此外,您还可以定义和使用 自定义算子。定义自定义算子包括为其定义 Fake 实现,就像任何其他 PyTorch 算子一样(参见上一节)。
这是一个包装 NumPy 的自定义 sin 算子的示例,以及它注册的(平凡)Fake 实现。
@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: Tensor) -> Tensor:
x_np = x.numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np)
@torch.library.register_fake("mylib::sin")
def _(x: Tensor) -> Tensor:
return torch.empty_like(x)
有时您的自定义算子的 Fake 实现会涉及依赖数据的形状。以下是一个自定义 nonzero 的 Fake 实现可能的样子。
...
@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
nnz = torch.library.get_ctx().new_dynamic_size()
shape = [nnz, x.dim()]
return x.new_empty(shape, dtype=torch.int64)
模块状态:读取 vs. 更新#
模块状态包括参数、缓冲区和常规属性。
常规属性可以是任何类型。
另一方面,参数和缓冲区始终是 Tensor。
模块状态可以是动态的或静态的,具体取决于它们的类型(如上所述)。例如,self.training 是一个 bool,这意味着它是静态的;另一方面,任何参数或缓冲区都是动态的。
模块状态中包含的任何 Tensor 的形状都不能是动态的,即这些形状在导出时是固定的,并且不能在导出程序的执行之间改变。
访问规则#
所有模块状态都必须初始化。访问尚未初始化的模块状态将在导出时引发错误。
读取模块状态总是被允许的。.
更新模块状态是可能的,但必须遵循以下规则:
静态常规属性(例如,原始类型)可以被更新。读取和更新可以自由交错,并且正如预期的那样,任何读取都将始终看到最新更新的值。因为这些属性是静态的,所以我们也将固化这些值,因此生成代码将不会有任何实际“获取”或“设置”这些属性的指令。
动态常规属性(例如,Tensor 类型)不能被更新。要做到这一点,它必须在模块初始化期间注册为缓冲区。
缓冲区可以被更新,更新可以是原地更新(例如,
self.buffer[:] = ...)或非原地更新(例如,self.buffer = ...)。参数不能被更新。通常,参数仅在训练期间更新,而在推理期间不更新。我们建议使用
torch.no_grad()进行导出,以避免在导出时更新参数。