评价此页

torch.export 编程模型#

创建于:2024年12月18日 | 最后更新于:2025年6月11日

本文档旨在解释 torch.export.export() 的行为和能力。旨在帮助您直观理解 torch.export.export() 如何处理代码。

追踪基础知识#

torch.export.export() 通过追踪模型在“示例”输入上的执行,并记录观察到的 PyTorch 操作和条件来捕获代表模型的图。然后,该图可以针对满足相同条件的其他输入进行运行。

torch.export.export() 的基本输出是单个 PyTorch 操作图,以及相关的元数据。输出的具体格式在 torch.export IR 规范 中有详细介绍。

严格追踪与非严格追踪#

torch.export.export() 提供两种追踪模式。

非严格模式下,我们通过正常的 Python 解释器来跟踪程序。您的代码的执行方式与贪婪模式完全相同;唯一的区别是所有 Tensor 都被替换为 Fake Tensors它们具有形状和其他元数据,但没有数据,并被包装在 Proxy 对象 中,这些对象会将所有对它们的运算记录到一个图中。我们还捕获 Tensor 形状的条件这些条件用于保证生成代码的正确性

严格模式下,我们首先通过 TorchDynamo(一个 Python 字节码分析引擎)来跟踪程序。TorchDynamo 实际上不执行您的 Python 代码。相反,它对其进行符号化分析,并根据结果构建一个图。一方面,这种分析允许 torch.export.export() 提供额外的 Python 级别安全保证(除了捕获 Tensor 形状的条件,如非严格模式)。另一方面,并非所有 Python 功能都支持此分析。

尽管目前跟踪的默认模式是严格模式,但我们强烈建议使用非严格模式,它将很快成为默认模式。对于大多数模型而言,Tensor 形状的条件足以保证正确性,而额外的 Python 级别安全保证没有影响;同时,在 TorchDynamo 中遇到不支持的 Python 功能的可能性会带来不必要的风险。

在本文档的其余部分,我们假定我们正在以非严格模式进行跟踪;特别是,我们假定所有 Python 功能都得到支持

值:静态 vs. 动态#

理解 torch.export.export() 行为的关键概念是静态动态值之间的区别。

静态值#

静态值是在导出时固定的,并且在导出程序的不同执行之间不会改变。在跟踪期间遇到该值时,我们将其视为常量并将其硬编码到图中。

当执行一个运算(例如 x + y)并且所有输入都是静态的,该运算的输出将直接硬编码到图中,该运算本身则不会显示(即它被“常量折叠”了)。

当一个值被硬编码到图中时,我们说该图已经被特化为该值。例如

import torch

class MyMod(torch.nn.Module):
    def forward(self, x, y):
        z = y + 7
        return x + z

m = torch.export.export(MyMod(), (torch.randn(1), 3))
print(m.graph_module.code)

"""
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 10);  arg0_1 = None
    return (add,)

"""

在这里,我们将 3 作为 y 的跟踪值;它被视为静态值并加到 7 上,在图中烧入静态值 10

动态值#

动态值是每次运行时都可能改变的值。它的行为就像一个“普通”的函数参数:您可以传递不同的输入并期望函数能够正确执行。

哪些值是静态的,哪些是动态的?#

一个值是静态还是动态取决于它的类型

  • 对于 Tensor

    • Tensor 的数据被视为动态。

    • Tensor 的形状可以被系统视为静态或动态。

      • 默认情况下,所有输入 Tensor 的形状都被视为静态。用户可以通过为任何输入 Tensor 指定一个动态形状来覆盖此行为。

      • 作为模块状态一部分的 Tensor(即参数和缓冲区)始终具有静态形状。

    • 其他形式的 Tensor元数据(例如 devicedtype)是静态的。

  • Python基本类型intfloatboolstrNone)是静态的。

    • 有一些基本类型的动态变体(SymIntSymFloatSymBool)。通常用户不必处理它们。

  • 对于 Python标准容器listtupledictnamedtuple

    • 结构(即 listtuple 的长度,以及 dictnamedtuple 的键序列)是静态的。

    • 包含的元素将递归地应用这些规则(基本上是 PyTree 方案),叶子是 Tensor 或基本类型。

  • 其他(包括数据类)可以注册为 PyTree(见下文),并遵循与标准容器相同的规则。

输入类型#

输入将根据其类型(如上所述)被视为静态或动态。

  • 静态输入将被硬编码到图中,在运行时传递不同的值将导致错误。请记住,这些大多是基本类型的值。

  • 动态输入表现得像“普通”函数输入。请记住,这些大多是 Tensor 类型的值。

默认情况下,您可以用于程序的输入类型是

  • 张量

  • Python 基本类型(intfloatboolstrNone

  • Python 标准容器(listtupledictnamedtuple

自定义输入类型#

此外,您还可以定义自己的(自定义)类并将其用作输入类型,但您需要将此类注册为 PyTree。

这是一个使用实用工具注册用作输入类型的 dataclass 的示例。

@dataclass
class Input:
    f: torch.Tensor
    p: torch.Tensor

torch.export.register_dataclass(Input)

class M(torch.nn.Module):
    def forward(self, x: Input):
        return x.f + 1

torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))

可选输入类型#

对于未传入程序的可选输入,torch.export.export() 将特化为它们的默认值。因此,导出的程序将要求用户显式传入所有参数,并会丢失默认行为。例如

class M(torch.nn.Module):
    def forward(self, x, y=None):
        if y is not None:
            return y * x
        return x + x

# Optional input is passed in
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep)
"""
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
            # File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
            mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x);  y = x = None
            return (mul,)
"""

# Optional input is not passed in
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
"""
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]", y):
            # File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)
"""

控制流:静态 vs. 动态#

torch.export.export() 支持控制流。控制流的行为取决于您分支的值是静态还是动态。

静态控制流#

关于静态值的 Python 控制流被透明地支持。(回想一下,静态值包括静态形状,因此关于静态形状的控制流也包含在此情况中。)

如上所述,我们“烧入”静态值,因此导出的图将永远不会看到关于静态值的任何控制流。

if 语句的情况下,我们将继续跟踪导出时所采取的分支。在 forwhile 语句的情况下,我们将通过展开循环来继续跟踪。

动态控制流:依赖形状 vs. 依赖数据#

当控制流中涉及的值是动态的时,它可能依赖于动态形状或动态数据。鉴于编译器使用形状信息而不是数据进行跟踪,在这些情况下对编程模型的影响是不同的。

动态依赖形状的控制流#

当控制流中涉及的值是动态形状时,在大多数情况下我们也会在跟踪时知道动态形状的具体值:有关编译器如何跟踪此信息的更多详细信息,请参阅下一节。

在这些情况下,我们称控制流是依赖形状的。我们使用动态形状的具体值来评估条件,以确定是True还是False,然后继续跟踪(如上所述),并额外发出一个对应于刚刚评估的条件的 guard。

否则,控制流被认为是依赖数据的。我们无法将条件评估为TrueFalse,因此无法继续跟踪,并且必须在导出时引发错误。请参阅下一节。

动态依赖数据的控制流#

支持动态值的依赖数据的控制流,但您必须使用 PyTorch 的显式运算符之一来继续跟踪。不允许在动态值上使用 Python 控制流语句,因为编译器无法评估继续跟踪所需的条件,因此必须在导出时引发错误。

我们提供用于表示动态值的通用条件和循环的运算符,例如 torch.condtorch.map。请注意,只有当您确实想要依赖数据的控制流时,才需要使用它们。

这是一个关于数据依赖条件 x.sum() > 0if 语句的示例,其中 x 是一个输入 Tensor,使用 torch.cond 重写。与其必须决定跟踪哪个分支,不如现在两个分支都会被跟踪。

class M_old(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x.sin()
        else:
            return x.cos()

class M_new(torch.nn.Module):
    def forward(self, x):
        return torch.cond(
            pred=x.sum() > 0,
            true_fn=lambda x: x.sin(),
            false_fn=lambda x: x.cos(),
            operands=(x,),
        )

数据依赖控制流的一个特殊情况是当它涉及依赖数据的动态形状时:通常,某些中间 Tensor 的形状依赖于输入数据而不是输入形状(因此不是依赖形状的)。在这种情况下,您可以使用断言来决定条件是True还是False,而不是使用控制流运算符。给定这样的断言,我们可以继续跟踪,并发出一个 guard。

我们提供用于表示动态形状断言的运算符,例如 torch._check。请注意,只有当存在依赖数据的动态形状的控制流时,您才需要使用它。

这是一个关于涉及依赖数据的动态形状的条件 nz.shape[0] > 0if 语句的示例,其中 nz 是调用 torch.nonzero() 的结果,这是一个输出形状依赖于输入数据的运算符。与其重写它,不如使用 torch._check 添加断言来有效地决定跟踪哪个分支。

class M_old(torch.nn.Module):
    def forward(self, x):
        nz = x.nonzero()
        if nz.shape[0] > 0:
            return x.sin()
        else:
            return x.cos()

class M_new(torch.nn.Module):
    def forward(self, x):
        nz = x.nonzero()
        torch._check(nz.shape[0] > 0)
        if nz.shape[0] > 0:
            return x.sin()
        else:
            return x.cos()

符号形状基础#

在跟踪期间,动态 Tensor 形状和关于它们的条件被编码为“符号表达式”。(相比之下,静态 Tensor 形状和关于它们的条件仅仅是 intbool 值。)

符号就像一个变量;它描述了一个动态 Tensor 形状。

随着跟踪的进行,中间 Tensor 的形状可能由更通用的表达式描述,通常涉及整数算术运算符。这是因为对于大多数 PyTorch 运算符,输出 Tensor 的形状可以描述为输入 Tensor 形状的函数。例如,torch.cat() 的输出形状是其输入形状的总和。

此外,当我们遇到程序中的控制流时,我们会创建布尔表达式,通常涉及关系运算符,以描述沿跟踪路径的条件。这些表达式用于决定通过程序的哪个路径进行跟踪,并记录在形状环境中,以保证跟踪路径的正确性并评估随后创建的表达式。

我们将在下面简要介绍这些子系统。

PyTorch 运算符的 Fake 实现#

回想一下,在跟踪期间,我们使用Fake Tensors(没有数据)来执行程序。一般情况下,我们不能用 Fake Tensors 调用实际的 PyTorch 运算符实现。因此,每个运算符都需要有一个额外的 Fake(也称为“meta”)实现,它接收和输出 Fake Tensors,并且在形状和 Fake Tensors 所携带的其他元数据方面与实际实现的行为匹配。

例如,注意 torch.index_select() 的 Fake 实现如何使用输入形状计算输出形状(同时忽略输入数据并返回空输出数据)。

def meta_index_select(self, dim, index):
    result_size = list(self.size())
    if self.dim() > 0:
        result_size[dim] = index.numel()
    return self.new_empty(result_size)

形状传播:已支持 vs. 未支持的动态形状#

形状通过 PyTorch 运算符的 Fake 实现进行传播。

要理解动态形状的传播,特别是其传播方式,一个关键概念是已支持(backed)和未支持(unbacked)动态形状之间的区别:我们知道前者(已支持)的具体值,但不知道后者(未支持)的具体值。

形状的传播,包括跟踪已支持和未支持的动态形状,按如下方式进行

  • 表示输入的 Tensor 的形状可以是静态的或动态的。当是动态的时,它们由符号描述;此外,这些符号是已支持的,因为我们还知道它们在导出时由用户提供的“真实”示例输入的具体值

  • 运算符的输出形状由其 Fake 实现计算,可以是静态的或动态的。当是动态的时,通常由符号表达式描述。此外

    • 如果输出形状仅依赖于输入形状,则当所有输入形状都是静态的或已支持的动态的时,它要么是静态的,要么是已支持的动态。

    • 另一方面,如果输出形状依赖于输入数据,那么它必然是动态的,而且,因为我们无法知道其具体值,所以它是未支持的

控制流:Guards 和 Assertions#

遇到形状条件时,它要么只涉及静态形状,在这种情况下它是一个 bool,要么涉及动态形状,在这种情况下它是一个符号布尔表达式。对于后者

  • 当条件仅涉及已支持的动态形状时,我们可以使用这些动态形状的具体值将条件评估为TrueFalse。然后,我们可以将一个 guard 添加到形状环境中,声明相应的符号布尔表达式为TrueFalse,然后继续跟踪。

  • 否则,条件涉及未支持的动态形状。通常,我们无法在没有额外信息的情况下评估此类条件;因此,我们无法继续跟踪,并且必须在导出时引发错误。用户应该使用显式的 PyTorch 运算符进行跟踪以继续。这些信息作为 guard 添加到形状环境中,并且还可能帮助评估随后遇到的其他条件为TrueFalse

模型导出后,任何关于已支持动态形状的 guard 都可以被理解为关于输入动态形状的条件。这些条件将针对导出时必须提供的动态形状规范进行验证,该规范描述了不仅示例输入,而且所有未来输入都应满足的动态形状条件,以保证生成代码的正确性。更准确地说,动态形状规范必须逻辑上暗示生成的 guard,否则将在导出时引发错误(并提供对动态形状规范的修复建议)。另一方面,当没有关于已支持动态形状的 guard 时(特别是当所有形状都是静态的时),则无需向导出提供动态形状规范。通常,动态形状规范被转换为生成代码的输入的运行时断言。

最后,任何关于未支持动态形状的 guard 都被转换为“内联”运行时断言。这些断言将添加到生成代码中,在那些未支持动态形状被创建的位置:通常是数据依赖运算符调用之后。

允许的 PyTorch 运算符#

所有 PyTorch 运算符都是允许的。

自定义运算符#

此外,您还可以定义和使用自定义运算符。定义自定义运算符包括为其定义 Fake 实现,就像任何其他 PyTorch 运算符一样(参见上一节)。

这是一个包装 NumPy 的自定义 sin 运算符及其注册的(平凡)Fake 实现的示例。

@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: Tensor) -> Tensor:
    x_np = x.numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np)

@torch.library.register_fake("mylib::sin")
def _(x: Tensor) -> Tensor:
    return torch.empty_like(x)

有时您的自定义运算符的 Fake 实现会涉及依赖数据的形状。以下是自定义 nonzero 的 Fake 实现可能的样子。

...

@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
    nnz = torch.library.get_ctx().new_dynamic_size()
    shape = [nnz, x.dim()]
    return x.new_empty(shape, dtype=torch.int64)

模块状态:读取 vs. 更新#

模块状态包括参数、缓冲区和常规属性。

  • 常规属性可以是任何类型。

  • 另一方面,参数和缓冲区始终是 Tensor。

模块状态可以是动态的或静态的,具体取决于它们的类型(如上所述)。例如,self.training 是一个 bool,这意味着它是静态的;另一方面,任何参数或缓冲区都是动态的。

模块状态中包含的任何 Tensor 的形状都不能是动态的,也就是说,这些形状在导出时是固定的,并且在导出程序的执行之间不能改变。

访问规则#

所有模块状态都必须初始化。访问尚未初始化的模块状态将在导出时引发错误。

读取模块状态始终是允许的.

更新模块状态是可能的,但必须遵循以下规则

  • 一个静态的常规属性(例如,基本类型)可以被更新。读取和更新可以自由地交错,并且如预期,任何读取都将始终看到最新更新的值。由于这些属性是静态的,我们也将烧入值,因此生成的代码将不会有任何实际“获取”或“设置”此类属性的指令。

  • 一个动态的常规属性(例如,Tensor 类型)不能被更新。要做到这一点,它必须在模块初始化期间注册为缓冲区。

  • 缓冲区可以被更新,其中更新可以是就地(例如,self.buffer[:] = ...)或非就地(例如,self.buffer = ...)。

  • 参数不能被更新。通常,参数仅在训练期间更新,而不是在推理期间更新。我们建议使用 torch.no_grad() 进行导出,以避免在导出时更新参数。

函数化的影响#

任何读取和/或更新的动态模块状态将被“提升”(分别)作为生成代码的输入和/或输出。

导出的程序将参数和缓冲区的初始值以及其他 Tensor 属性的常量值与生成的代码一起存储。