torch.export 编程模型#
创建于:2024年12月18日 | 最后更新于:2025年6月11日
本文档旨在解释 torch.export.export()
的行为和能力。旨在帮助您直观理解 torch.export.export()
如何处理代码。
追踪基础知识#
torch.export.export()
通过追踪模型在“示例”输入上的执行,并记录观察到的 PyTorch 操作和条件来捕获代表模型的图。然后,该图可以针对满足相同条件的其他输入进行运行。
torch.export.export()
的基本输出是单个 PyTorch 操作图,以及相关的元数据。输出的具体格式在 torch.export IR 规范 中有详细介绍。
严格追踪与非严格追踪#
torch.export.export()
提供两种追踪模式。
在非严格模式下,我们通过正常的 Python 解释器来跟踪程序。您的代码的执行方式与贪婪模式完全相同;唯一的区别是所有 Tensor 都被替换为 Fake Tensors,它们具有形状和其他元数据,但没有数据,并被包装在 Proxy 对象 中,这些对象会将所有对它们的运算记录到一个图中。我们还捕获 Tensor 形状的条件,这些条件用于保证生成代码的正确性。
在严格模式下,我们首先通过 TorchDynamo(一个 Python 字节码分析引擎)来跟踪程序。TorchDynamo 实际上不执行您的 Python 代码。相反,它对其进行符号化分析,并根据结果构建一个图。一方面,这种分析允许 torch.export.export()
提供额外的 Python 级别安全保证(除了捕获 Tensor 形状的条件,如非严格模式)。另一方面,并非所有 Python 功能都支持此分析。
尽管目前跟踪的默认模式是严格模式,但我们强烈建议使用非严格模式,它将很快成为默认模式。对于大多数模型而言,Tensor 形状的条件足以保证正确性,而额外的 Python 级别安全保证没有影响;同时,在 TorchDynamo 中遇到不支持的 Python 功能的可能性会带来不必要的风险。
在本文档的其余部分,我们假定我们正在以非严格模式进行跟踪;特别是,我们假定所有 Python 功能都得到支持。
值:静态 vs. 动态#
理解 torch.export.export()
行为的关键概念是静态和动态值之间的区别。
静态值#
静态值是在导出时固定的,并且在导出程序的不同执行之间不会改变。在跟踪期间遇到该值时,我们将其视为常量并将其硬编码到图中。
当执行一个运算(例如 x + y
)并且所有输入都是静态的,该运算的输出将直接硬编码到图中,该运算本身则不会显示(即它被“常量折叠”了)。
当一个值被硬编码到图中时,我们说该图已经被特化为该值。例如
import torch
class MyMod(torch.nn.Module):
def forward(self, x, y):
z = y + 7
return x + z
m = torch.export.export(MyMod(), (torch.randn(1), 3))
print(m.graph_module.code)
"""
def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(arg0_1, 10); arg0_1 = None
return (add,)
"""
在这里,我们将 3
作为 y
的跟踪值;它被视为静态值并加到 7
上,在图中烧入静态值 10
。
动态值#
动态值是每次运行时都可能改变的值。它的行为就像一个“普通”的函数参数:您可以传递不同的输入并期望函数能够正确执行。
哪些值是静态的,哪些是动态的?#
一个值是静态还是动态取决于它的类型
对于 Tensor
Tensor 的数据被视为动态。
Tensor 的形状可以被系统视为静态或动态。
默认情况下,所有输入 Tensor 的形状都被视为静态。用户可以通过为任何输入 Tensor 指定一个动态形状来覆盖此行为。
作为模块状态一部分的 Tensor(即参数和缓冲区)始终具有静态形状。
其他形式的 Tensor元数据(例如
device
、dtype
)是静态的。
Python基本类型(
int
、float
、bool
、str
、None
)是静态的。有一些基本类型的动态变体(
SymInt
、SymFloat
、SymBool
)。通常用户不必处理它们。
对于 Python标准容器(
list
、tuple
、dict
、namedtuple
)结构(即
list
和tuple
的长度,以及dict
和namedtuple
的键序列)是静态的。包含的元素将递归地应用这些规则(基本上是 PyTree 方案),叶子是 Tensor 或基本类型。
其他类(包括数据类)可以注册为 PyTree(见下文),并遵循与标准容器相同的规则。
输入类型#
输入将根据其类型(如上所述)被视为静态或动态。
静态输入将被硬编码到图中,在运行时传递不同的值将导致错误。请记住,这些大多是基本类型的值。
动态输入表现得像“普通”函数输入。请记住,这些大多是 Tensor 类型的值。
默认情况下,您可以用于程序的输入类型是
张量
Python 基本类型(
int
、float
、bool
、str
、None
)Python 标准容器(
list
、tuple
、dict
、namedtuple
)
自定义输入类型#
此外,您还可以定义自己的(自定义)类并将其用作输入类型,但您需要将此类注册为 PyTree。
这是一个使用实用工具注册用作输入类型的 dataclass 的示例。
@dataclass
class Input:
f: torch.Tensor
p: torch.Tensor
torch.export.register_dataclass(Input)
class M(torch.nn.Module):
def forward(self, x: Input):
return x.f + 1
torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))
可选输入类型#
对于未传入程序的可选输入,torch.export.export()
将特化为它们的默认值。因此,导出的程序将要求用户显式传入所有参数,并会丢失默认行为。例如
class M(torch.nn.Module):
def forward(self, x, y=None):
if y is not None:
return y * x
return x + x
# Optional input is passed in
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
# File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x); y = x = None
return (mul,)
"""
# Optional input is not passed in
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y):
# File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x); x = None
return (add,)
"""
控制流:静态 vs. 动态#
torch.export.export()
支持控制流。控制流的行为取决于您分支的值是静态还是动态。
静态控制流#
关于静态值的 Python 控制流被透明地支持。(回想一下,静态值包括静态形状,因此关于静态形状的控制流也包含在此情况中。)
如上所述,我们“烧入”静态值,因此导出的图将永远不会看到关于静态值的任何控制流。
在 if
语句的情况下,我们将继续跟踪导出时所采取的分支。在 for
或 while
语句的情况下,我们将通过展开循环来继续跟踪。
动态控制流:依赖形状 vs. 依赖数据#
当控制流中涉及的值是动态的时,它可能依赖于动态形状或动态数据。鉴于编译器使用形状信息而不是数据进行跟踪,在这些情况下对编程模型的影响是不同的。
动态依赖形状的控制流#
当控制流中涉及的值是动态形状时,在大多数情况下我们也会在跟踪时知道动态形状的具体值:有关编译器如何跟踪此信息的更多详细信息,请参阅下一节。
在这些情况下,我们称控制流是依赖形状的。我们使用动态形状的具体值来评估条件,以确定是True
还是False
,然后继续跟踪(如上所述),并额外发出一个对应于刚刚评估的条件的 guard。
否则,控制流被认为是依赖数据的。我们无法将条件评估为True
或False
,因此无法继续跟踪,并且必须在导出时引发错误。请参阅下一节。
动态依赖数据的控制流#
支持动态值的依赖数据的控制流,但您必须使用 PyTorch 的显式运算符之一来继续跟踪。不允许在动态值上使用 Python 控制流语句,因为编译器无法评估继续跟踪所需的条件,因此必须在导出时引发错误。
我们提供用于表示动态值的通用条件和循环的运算符,例如 torch.cond
、torch.map
。请注意,只有当您确实想要依赖数据的控制流时,才需要使用它们。
这是一个关于数据依赖条件 x.sum() > 0
的 if
语句的示例,其中 x
是一个输入 Tensor,使用 torch.cond
重写。与其必须决定跟踪哪个分支,不如现在两个分支都会被跟踪。
class M_old(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
return torch.cond(
pred=x.sum() > 0,
true_fn=lambda x: x.sin(),
false_fn=lambda x: x.cos(),
operands=(x,),
)
数据依赖控制流的一个特殊情况是当它涉及依赖数据的动态形状时:通常,某些中间 Tensor 的形状依赖于输入数据而不是输入形状(因此不是依赖形状的)。在这种情况下,您可以使用断言来决定条件是True
还是False
,而不是使用控制流运算符。给定这样的断言,我们可以继续跟踪,并发出一个 guard。
我们提供用于表示动态形状断言的运算符,例如 torch._check
。请注意,只有当存在依赖数据的动态形状的控制流时,您才需要使用它。
这是一个关于涉及依赖数据的动态形状的条件 nz.shape[0] > 0
的 if
语句的示例,其中 nz
是调用 torch.nonzero()
的结果,这是一个输出形状依赖于输入数据的运算符。与其重写它,不如使用 torch._check
添加断言来有效地决定跟踪哪个分支。
class M_old(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
torch._check(nz.shape[0] > 0)
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
符号形状基础#
在跟踪期间,动态 Tensor 形状和关于它们的条件被编码为“符号表达式”。(相比之下,静态 Tensor 形状和关于它们的条件仅仅是 int
和 bool
值。)
符号就像一个变量;它描述了一个动态 Tensor 形状。
随着跟踪的进行,中间 Tensor 的形状可能由更通用的表达式描述,通常涉及整数算术运算符。这是因为对于大多数 PyTorch 运算符,输出 Tensor 的形状可以描述为输入 Tensor 形状的函数。例如,torch.cat()
的输出形状是其输入形状的总和。
此外,当我们遇到程序中的控制流时,我们会创建布尔表达式,通常涉及关系运算符,以描述沿跟踪路径的条件。这些表达式用于决定通过程序的哪个路径进行跟踪,并记录在形状环境中,以保证跟踪路径的正确性并评估随后创建的表达式。
我们将在下面简要介绍这些子系统。
PyTorch 运算符的 Fake 实现#
回想一下,在跟踪期间,我们使用Fake Tensors(没有数据)来执行程序。一般情况下,我们不能用 Fake Tensors 调用实际的 PyTorch 运算符实现。因此,每个运算符都需要有一个额外的 Fake(也称为“meta”)实现,它接收和输出 Fake Tensors,并且在形状和 Fake Tensors 所携带的其他元数据方面与实际实现的行为匹配。
例如,注意 torch.index_select()
的 Fake 实现如何使用输入形状计算输出形状(同时忽略输入数据并返回空输出数据)。
def meta_index_select(self, dim, index):
result_size = list(self.size())
if self.dim() > 0:
result_size[dim] = index.numel()
return self.new_empty(result_size)
形状传播:已支持 vs. 未支持的动态形状#
形状通过 PyTorch 运算符的 Fake 实现进行传播。
要理解动态形状的传播,特别是其传播方式,一个关键概念是已支持(backed)和未支持(unbacked)动态形状之间的区别:我们知道前者(已支持)的具体值,但不知道后者(未支持)的具体值。
形状的传播,包括跟踪已支持和未支持的动态形状,按如下方式进行
表示输入的 Tensor 的形状可以是静态的或动态的。当是动态的时,它们由符号描述;此外,这些符号是已支持的,因为我们还知道它们在导出时由用户提供的“真实”示例输入的具体值。
运算符的输出形状由其 Fake 实现计算,可以是静态的或动态的。当是动态的时,通常由符号表达式描述。此外
如果输出形状仅依赖于输入形状,则当所有输入形状都是静态的或已支持的动态的时,它要么是静态的,要么是已支持的动态。
另一方面,如果输出形状依赖于输入数据,那么它必然是动态的,而且,因为我们无法知道其具体值,所以它是未支持的。
控制流:Guards 和 Assertions#
遇到形状条件时,它要么只涉及静态形状,在这种情况下它是一个 bool
,要么涉及动态形状,在这种情况下它是一个符号布尔表达式。对于后者
当条件仅涉及已支持的动态形状时,我们可以使用这些动态形状的具体值将条件评估为
True
或False
。然后,我们可以将一个 guard 添加到形状环境中,声明相应的符号布尔表达式为True
或False
,然后继续跟踪。否则,条件涉及未支持的动态形状。通常,我们无法在没有额外信息的情况下评估此类条件;因此,我们无法继续跟踪,并且必须在导出时引发错误。用户应该使用显式的 PyTorch 运算符进行跟踪以继续。这些信息作为 guard 添加到形状环境中,并且还可能帮助评估随后遇到的其他条件为
True
或False
。
模型导出后,任何关于已支持动态形状的 guard 都可以被理解为关于输入动态形状的条件。这些条件将针对导出时必须提供的动态形状规范进行验证,该规范描述了不仅示例输入,而且所有未来输入都应满足的动态形状条件,以保证生成代码的正确性。更准确地说,动态形状规范必须逻辑上暗示生成的 guard,否则将在导出时引发错误(并提供对动态形状规范的修复建议)。另一方面,当没有关于已支持动态形状的 guard 时(特别是当所有形状都是静态的时),则无需向导出提供动态形状规范。通常,动态形状规范被转换为生成代码的输入的运行时断言。
最后,任何关于未支持动态形状的 guard 都被转换为“内联”运行时断言。这些断言将添加到生成代码中,在那些未支持动态形状被创建的位置:通常是数据依赖运算符调用之后。
允许的 PyTorch 运算符#
所有 PyTorch 运算符都是允许的。
自定义运算符#
此外,您还可以定义和使用自定义运算符。定义自定义运算符包括为其定义 Fake 实现,就像任何其他 PyTorch 运算符一样(参见上一节)。
这是一个包装 NumPy 的自定义 sin
运算符及其注册的(平凡)Fake 实现的示例。
@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: Tensor) -> Tensor:
x_np = x.numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np)
@torch.library.register_fake("mylib::sin")
def _(x: Tensor) -> Tensor:
return torch.empty_like(x)
有时您的自定义运算符的 Fake 实现会涉及依赖数据的形状。以下是自定义 nonzero
的 Fake 实现可能的样子。
...
@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
nnz = torch.library.get_ctx().new_dynamic_size()
shape = [nnz, x.dim()]
return x.new_empty(shape, dtype=torch.int64)
模块状态:读取 vs. 更新#
模块状态包括参数、缓冲区和常规属性。
常规属性可以是任何类型。
另一方面,参数和缓冲区始终是 Tensor。
模块状态可以是动态的或静态的,具体取决于它们的类型(如上所述)。例如,self.training
是一个 bool
,这意味着它是静态的;另一方面,任何参数或缓冲区都是动态的。
模块状态中包含的任何 Tensor 的形状都不能是动态的,也就是说,这些形状在导出时是固定的,并且在导出程序的执行之间不能改变。
访问规则#
所有模块状态都必须初始化。访问尚未初始化的模块状态将在导出时引发错误。
读取模块状态始终是允许的.
更新模块状态是可能的,但必须遵循以下规则
一个静态的常规属性(例如,基本类型)可以被更新。读取和更新可以自由地交错,并且如预期,任何读取都将始终看到最新更新的值。由于这些属性是静态的,我们也将烧入值,因此生成的代码将不会有任何实际“获取”或“设置”此类属性的指令。
一个动态的常规属性(例如,Tensor 类型)不能被更新。要做到这一点,它必须在模块初始化期间注册为缓冲区。
缓冲区可以被更新,其中更新可以是就地(例如,
self.buffer[:] = ...
)或非就地(例如,self.buffer = ...
)。参数不能被更新。通常,参数仅在训练期间更新,而不是在推理期间更新。我们建议使用
torch.no_grad()
进行导出,以避免在导出时更新参数。