评价此页

Dynamo 深度探索#

创建于:2024 年 4 月 2 日 | 最后更新于:2025 年 6 月 10 日

TorchDynamo(或简称 Dynamo)是 torch.compile 中的跟踪器,它往往是那些令人抓狂的回溯的罪魁祸首。然而,我们不能盲目地将这些错误归咎于 Dynamo。为了向用户提供其所具备的灵活性,Dynamo 承担着理解任何 Python 程序的艰巨任务。特别是,Dynamo 必须在内部实现 Python 编程语言的很大一部分!

在这篇文章中,我们将从头开始探讨 Dynamo 的内部设计。我们将讨论它提供的功能以及如何实现这些功能。读完本文后,您将更好地理解当您 torch.compiled 一个 PyTorch 程序时,为什么编译会出错,或者虽然成功但加速效果不如预期。

Dynamo 简介#

在深入了解所有实现细节之前,让我们先讨论一下 Dynamo 的作用。

Dynamo 是一个跟踪器。这意味着,给定一个函数及其输入,它会执行该函数并将指令的线性序列(无控制流)记录到一个图中。例如,考虑以下程序:

import torch

@torch.compile
def mse(x, y):
    z = (x - y) ** 2
    return z.sum()

x = torch.randn(200)
y = torch.randn(200)
mse(x, y)

如果我们将此程序保存到文件 example.py 中,并运行

TORCH_LOGS=graph_code python example.py

我们看到 Dynamo 跟踪的输出:

def forward(l_x_: torch.Tensor, l_y_: torch.Tensor):
    # File: example.py:5, code: z = (x - y) ** 2
    sub = l_x_ - l_y_
    z = sub ** 2
    # File: example.py:6, code: return z.sum()
    sum_1 = z.sum()
    return (sum_1,)

我们将其称为**给定输入的函数图(或跟踪)**。这通过 FX 图 表示。我们简单地将 FX 图视为一个存储函数调用列表的容器。

我们首先要注意的是,该图是 PyTorch 操作的线性序列。1 Dynamo 记录所有 PyTorch 操作并按顺序存储它们。例如,它将 z = (x - y) ** 2 分解为其两个组成操作:sub = l_x_ - l_y_z = sub ** 2

当说跟踪是线性的时,我们指的是没有分支或任何控制流。要了解这一点,请考虑:

import torch

@torch.compile
def fn(x, n):
    y = x ** 2
    if n >= 0:
        return (n + 1) * y
    else:
        return y / n

x = torch.randn(200)
fn(x, 2)

在执行时,带有 TORCH_LOGS=graph_code,返回:

def forward(l_x_: torch.Tensor):
    # File: example.py:5, code: y = x ** 2
    y = l_x_ ** 2
    # File: example.py:7, code: return (n + 1) * y
    mul = 3 * y
    return (mul,)

我们看到 Dynamo 完全从跟踪中删除了 if 语句,只记录了使用输入执行的操作。

因此,应该清楚的是,**函数的跟踪取决于输入**。特别是,这意味着跟踪不是在我们编写 @torch.compile 时生成的,而是在我们使用实际参数执行函数 fn(x, 2) 时生成的。

这里要注意的另一个有趣的事情是 Dynamo 删除了函数的第二个参数。相反,它将其视为一个常量,并在图中记录了操作 n + 1 的结果。这是 Dynamo 的另一个特性:Dynamo 会将任何非张量值视为常量……除了整数。现在让我们看看整数有什么特别之处。

Dynamo 的最后一个定义属性是它知道如何处理动态形状。符号形状指的是 Dynamo 跟踪形状的能力,更普遍地说,是跟踪整数而不是将它们保留为常量。这可以避免重新编译并部署适用于生产中任何大小的通用模型。动态形状出现的主要例子是批处理大小,我们可能使用固定的批处理大小训练模型,但随后对任意批处理大小执行推理,或者是在处理文本或音频时遇到的可变序列长度。

我们可以通过多次执行上述示例来看到这一点:

import torch

@torch.compile
def fn(x, n):
    y = x ** 2
    if n >= 0:
        return (n + 1) * y
    else:
        return y / n

x = torch.randn(200)
fn(x, 2)
fn(x, 3)
fn(x, -2)

在这种情况下,TORCH_LOGS=graph_code 生成了另外两个图:

# Graph for n==2 omitted

def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
    # File: a.py:5, code: y = x ** 2
    y = l_x_ ** 2

    # File: a.py:7, code: return (n + 1) * y
    add = l_n_ + 1
    mul = add * y
    return (mul,)
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
    # File: a.py:5, code: y = x ** 2
    y = l_x_ ** 2

    # File: a.py:9, code: return y / n
    truediv = y / l_n_
    return (truediv,)

Dynamo 检测到第一个调用后有一个整数改变了它的值,并开始跟踪它。我们看到这些图是通用的,通过一个 SymInt 类型的对象象征性地跟踪变量 n

如果在这些调用之后我们调用 fn(x, 4),Dynamo 将不会重新编译,而是重用已经跟踪的图。

总结一下:1. Dynamo 是一个 Python 跟踪器。2. 给定一些输入,它会返回一个包含已执行的 PyTorch 函数的 FX 图。3. 如果检测到整数在调用之间发生变化,它还可以跟踪整数。4. 它会专门处理任何非张量或标量的值。

当然,Dynamo 还有很多其他功能,例如找出何时需要重新跟踪、重写函数的字节码、实现图中断等等。为了保持介绍的简洁,我们将在后续逐步讨论所有这些内容。

PEP 523:向 CPython 添加帧评估 API#

现在假设我们被赋予实现 Dynamo 的任务。我们该从何开始呢?对我们来说,非常方便的是,PEP 523 随 Python 3.6 发布。该 PEP 旨在允许第三方为 Python 创建 JIT 编译器。让我们看看如何做到这一点。

关于 CPython 的注意事项:CPython 在内部实现为堆栈机器。Python 程序被编译成字节码,然后由这个解释器执行。要了解更多关于这些字节码的信息,请参阅标准库中的dis 模块。另请参阅开发者文档,以了解 CPython 解释器的介绍。我们将假设读者熟悉堆栈机器的概念。

PEP 523 暴露了一个 API,用户可以添加一个自定义的每函数解释器。然后,CPython 将使用这个解释器而不是它自己的解释器来执行函数。为了能够执行函数,在进入时,CPython 向自定义解释器提供诸如以下信息:- 函数的字节码 - 函数参数的值(即局部变量)及其名称 - 全局变量的值及其名称 - 内置函数,如 absprint

您可以在此处查看所有字段。2

总之,CPython 为用户的解释器提供了执行函数所需的所有信息。3

通过这个 API,我们可以通过实现一个解释器来运行代码并记录在此执行期间发生的所有 PyTorch 操作到一个图中,从而实现一个跟踪器。这正是 Dynamo 所做的。

Dynamo 使用这个 CPython API 来解析所有这些对象,并将它们打包到一个 Python 结构中。完成这些后……它又从 C 返回到 Python。除了这部分与 CPython 通信的代码外,Dynamo 完全用 Python 实现。

应该清楚的是,装饰器 @torch.compile 的作用是安装必要的脚手架,以便在调用函数时将字节码、参数、全局变量等传递给 Dynamo。同样,@torch.compile 实际上并没有编译任何东西。

用 Python 实现 CPython#

所以,我们回到了 Python 世界。我们有了函数的字节码,以及执行它所需的所有上下文。特别是,我们已经到达了 _convert_frame_assert。这就是装饰器 torch.compile 返回的函数!我们从 _dynamo.optimize 获得了这个函数。装饰器 torch.compile 只是 _dynamo.optimize 周围的一个友好 API。

在开始实现 Python 解释器之前,我们想定义一个IR。特别是,我们想将所有局部变量和全局变量封装到我们自己的内部类中。这使我们能够更好地跟踪这些对象,并将 Dynamo 视为相同方式的对象分组在一起。

内部类结构的父类是 VariableTracker,它表示 Dynamo 理解的不同对象。例如,ListVariable 表示一个 list 对象,并在内部维护一个VariableTrackers 列表。另一个 VariableTracker 的例子是 ConstantVariable。ConstantVariable 封装了 Dynamo 视为常量的所有对象。我们还有针对需要特殊注意的对象的特殊子类,例如 TensorVariable。所有这些内部类都定义在 torch/_dynamo/variables 文件夹中。

Python 对象在 VariableBuilder._wrap 中被封装到其相应的 VariableTracker 类中。这个函数只是一个非常长的 elif 链,它尝试将 Python 输入递归地模式匹配到适当的 VariableTracker 类型。

调试提示。当我们从 Dynamo 得到意外结果时,有时是由于构建器造成的。如果构建器的逻辑错误,有时 Dynamo 可能会将变量包装到不正确的 VariableTracker 类型中,这可能会在以后导致问题。查看错误中出现的 VariableTracker 类型以及当您遇到 Dynamo 错误时抛出异常的 VariableTracker 方法非常有用。特别是,有时我们发现一个对象被跟踪为 UserDefinedObjectVariable(这是 Dynamo 的包罗万象的类),而它应该被跟踪为更具体的东西。在这种情况下,SourceBuilder.__call__ 逻辑通常是罪魁祸首。

调试提示。使用 TORCH_LOGS=dynamo 运行程序时,其中一个输出的工件是以下形式的行:

TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(<built-in method any>), TensorVariable()]

这是原始程序的字节码和该点的堆栈状态。这对于查找对象未跟踪到正确的 VariableTracker 中的位置非常有用。

好的,我们有一个用于跟踪器的 IR,现在我们*只需*重新实现 CPython 的堆栈机器。这由 InstructorTranslatorBasesymbolic_convert.py 中实现。

InstructionTranslatorBase 拥有大约 200 个方法,实现了几乎所有 Python 字节码。例如,我们可以看到 BUILD_LIST 的实现:

def BUILD_LIST(self, inst):
    items = self.popn(inst.argval)
    self.push(ListVariable(items, mutation_type=ValueMutationNew()))

这就是诸如 l = [2, 3, 4] 之类的构造所生成的字节码。在这种情况下,由于有三个元素,生成的字节码是 BUILD_LIST 3。这意味着我们从栈顶弹出 3 个元素,并将由这三个元素组成的新列表对象推送到栈顶。

生成输出图#

有了符号执行 Python 代码的方法,我们就可以提取在给定某些输入的情况下,程序符号执行期间发生的 PyTorch 操作。这在 Dynamo 中通过 OutputGraph 对象实现。OutputGraph 对象 绑定到一个 `InstructionTranslator` 对象,它跟踪创建将由 Dynamo 返回的 FX 图所需的所有数据。

FX 图的所有输入和中间元素都是 fx.Node。在 Dynamo 中,fx.Node 被封装在 fx.Proxy 中。fx.Proxy 用于构建 FX 图。特别是,它们记录在其上执行的每个 PyTorch 操作到图中。您可以通过调用 create_proxy 来创建要添加到图中的新操作。然后,我们可以通过函数 wrap_fx_proxy 将其添加到图中。

图存储张量上的操作……以及符号整数上的操作。我们稍后将讨论符号整数,但首先我们将讨论 Dynamo 如何解决一个相当重要的正确性问题。

使 Dynamo 健全:守卫#

此时,我们有一种完全不考虑控制流来跟踪程序的方法。为此,我们重新实现了 CPython 的所有功能……如果这听起来有点矫枉过正,那是因为它确实如此。torch.jit.trace 已经实现了这一点,而无需所有这些机制,那么这是为什么呢?

torch.jit.trace 的问题,正如其文档中警告的那样,是它仅在被跟踪的程序不依赖于数据时才有效。换句话说,它仅在程序本身是线性的情况下才有效。这意味着编写我们的程序时不能使用 if-else、for-while 循环、异常。更甚的是,我们使用的任何库都不能使用任何控制流!总而言之,在像 Python 这样动态的语言中不使用控制流实际上是一个巨大的限制。

JAX 通过在回溯后始终回溯和缓存图来解决这个问题。另一方面,Dynamo 使用守卫来避免每次都回溯整个程序。

守卫是为了将框架专门用于一组示例输入而做出的假设(对输入的布尔表达式)。仅当这些假设在新输入上成立时,重用图才有效。

例如,函数的任何常量输入,例如字符串,都会安装一个守卫,说明该输入应为 str 类型,并等于我们传入的字符串。运行

import torch

@torch.compile
def fn(a, b):
    return a * len(b)

fn(torch.arange(10), "Hello")

使用 TORCH_LOGS=guards 会打印(除其他守卫外)

___check_type_id(L['b'], 94334122025024)
L['b'] == 'Hello'

这表示“局部变量 b 应该具有特定类型(本例中为 str,由常量 9433... 表示),并且其值应为 'Hello'”。如果我们再次执行函数并传入不同的参数

import torch

@torch.compile
def fn(a, b):
    return a * len(b)

fn(torch.arange(10), "Hello")
fn(torch.arange(10), "Hi")

我们可以通过运行 TORCH_LOGS=recompiles 查看失败的守卫

Recompiling function fn in script.py:3
triggered by the following guard failure(s):
     - L['b'] == 'Hello'

构建器中包装函数输入程序执行期间,守卫会被累积。我们将在下一节中展示更多守卫的例子,但首先让我们讨论源。

跟踪如何从进入当前帧时存在的原始局部或全局变量中重建变量。特别是,它跟踪原始局部和全局对象以及它们包含的任何对象。在

def foo(x: Tensor, y: List[Tensor]):
    a = x * y[0]
    return a * x

xy 具有 LocalSource 作为它们的源,而 y[0] 具有 GetItemSource,它内部存储一个 LocalSource。另一方面,a 将没有源,因为它是一个只存在于 fx 图中的中间变量。

所有这些都定义在 torch/_dynamo/source.py 中。我们可以在下面的示例中看到由 GetItemSource 生成的守卫

import torch

@torch.compile
def fn(x, l):
    return x * len(l[0])

fn(torch.randn(8), ["Hi", "Hello"])

生成以下守卫

___check_type_id(L['l'], 94439025877664)
len(L['l']) == 2
___check_type_id(L['l'][0], 94439025840192)
L['l'][0] == 'Hi'
___check_type_id(L['l'][1], 94439025840192)
L['l'][1] == 'Hello'

在这里,我们看到由 GetItemSource ([0][1]) 生成的代码包装了 LocalSource (L['l'])。

此时,有了源和守卫,我们能够实现一个缓存系统,以避免每次都重新编译,而无需每次都重新跟踪。我们将在后续更详细地讨论这个缓存系统。

细心的读者可能已经注意到,这并不能解释为什么我们需要对 Python 解释器进行如此精细的控制,以至于不得不重新实现它。我们展示的守卫示例取决于输入对象,因此我们仍然可以在执行函数之前计算这些守卫。换句话说,我们可以在 torch.jit.trace 的基础上实现这个守卫系统,并以更少的精力获得相同的功能……进入符号形状。

符号形状#

我们在引言中讨论的另一点是 Dynamo 知道如何跟踪整数。为了实现这一点,我们使用一个符号类 torch.SymInt,它表现得像一个 int,但它将对其执行的所有操作记录在输出 FX 图中。4 我们在引言中引入符号整数跟踪时已经看到了这个类。

现在让我们讨论定义 Dynamo 中符号形状跟踪的三个属性,以及如何实现它们。

默认静态#

Dynamo 默认假定每个整数,无论是输入还是张量的形状,都是静态的。换句话说,函数首次执行时不会跟踪任何整数。然后,只有当它检测到整数或形状在执行期间改变值时,它才会跟踪它并生成一个在该变量上通用的图。

我们已经在介绍中使用整数看到了这种行为。现在让我们看一个使用张量形状的例子。

import torch

@torch.compile
def fn(a, b):
    return a.shape[0] * a * b

fn(torch.randn(4, 3), torch.randn(4, 3))
fn(torch.randn(8, 3), torch.randn(8, 3))

使用 TORCH_LOGS=graph_code 运行此程序,我们看到这两个调用被跟踪为

def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor):
    mul = 4 * l_a_
    mul_1 = mul * l_b_
    return (mul_1,)

def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor):
    size = l_a_.size()
    getitem = size[0]
    mul = getitem * l_a_
    mul_1 = mul * l_b_
    return (mul_1,)

在第一个图中,形状被跟踪为常量,但一旦它发生变化,它就会使用 SymInt 符号化地跟踪它。通常,查看中间值形状的更简单方法是使用 TORCH_LOGS=graph_sizes 运行程序

TRACED GRAPH TENSOR SIZES
===== __compiled_fn_1 =====
l_a_: (s0, 3)
l_a_ (concrete): (8, 3)
l_b_: (s0, 3)
l_b_ (concrete): (8, 3)
mul: (s0, 3)
mul (concrete): (8, 3)
mul_1: (s0, 3)
mul_1 (concrete): (8, 3)

我们可以看到,两个张量参数的第一个维度是动态的,因为它由 s0 变量表示。

我们可以通过运行 TORCH_LOGS=guards 了解 Dynamo 如何实现这一点

# Guards first call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])

# Guards second call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])

L['b'].size()[0] == L['a'].size()[0]
2 <= L['a'].size()[0]

我们看到在第一次调用时,守卫检查张量是否具有固定的尺寸和步幅。这些守卫在第二次执行中失败,因此它会重新跟踪。由于是 int 守卫失败,在第二次迭代中,它会符号化地跟踪这个 int,并在更通用的内核上安装更通用的守卫。

编译性能提示。如果您知道某个维度的大小会变化,您可以在调用 torch.compile 之前通过调用 torch._dynamo.mark_dynamic 将其标记为动态。这将避免首次以静态形状编译。还有其他有用的实用函数,例如 maybe_mark_dynamicmark_static。您还可以通过调用 torch.compile(dynamic=True) 来跟踪所有整数和形状。这主要用于调试目的。

0, 1 始终特殊化#

无论我们是否将某个维度标记为动态,如果传入的输入中该维度为 0 或 1,Dynamo 将其跟踪为非动态,并为其生成一个特定的图。这就是为什么在上面的示例中我们发现形式为 2 <= L['a'].size()[0] 的守卫。

这个选择有几个原因。其中有两个特别重要——张量为空当且仅当它的任何一个维度为零——张量只有在其中一个步长为一时才能是连续的

此策略决策不适用于普通的 Python 整数;如果我们认为 Python 整数应该动态编译,我们默认不会对其进行特殊化;相反,它是否被特殊化取决于其用途。

鸭子造型#

Dynamo 执行我们称之为“鸭子造型”的操作。如果在跟踪时两个动态整数具有相同的值,我们将假定它们相等并对其进行守卫。实际上,这意味着在上面的示例中,我们不是拥有两个符号 s0, s1,而是将它们统一为 s0 并设置了守卫 L['b'].size()[0] == L['a'].size()[0]。这使得编译器内部能够执行融合,同时能够生成足够通用的内核。

符号整数上的守卫#

我们现在从高层次上理解了符号形状是如何实现的以及它们具有的属性。那么,为什么符号形状迫使我们走上控制 CPython 解释器的棘手路线呢?考虑以下示例

import torch

@torch.compile(dynamic=True)
def fn(a):
    if a.shape[0] * 2 < 16:
        return a
    else:
        return a + 1

fn(torch.randn(8))

这段代码有一个形式为 2*L['a'].size()[0] >= 16 的守卫。这是一个关于函数输入而言非平凡的守卫,但它是在程序执行过程中注册的。更重要的是,我们无法知道是否需要这个守卫,直到我们看到基于 SymNodeVariable 参数的 if 语句条件。这些条件对于 torch.jit.trace 是不可见的,并且需要对 python 代码进行深入分析。

调试提示 运行此代码时使用 TORCH_LOGS=dynamo 会告诉我们此守卫的添加位置

eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr)

在此处设置断点并查看回溯对于理解守卫的来源非常有帮助。

使 Dynamo 完整:图中断#

借助我们讨论过的所有工具,我们拥有一个可以跟踪张量和整数上的 PyTorch 操作的跟踪器,并且它有一个缓存系统,知道何时可以重用以前跟踪过的图以及何时需要重新跟踪。所有这一切都在执行任意 Python 代码!

这只存在一个小问题。“执行任意 Python 代码”这个说法可能有点过于笼统了。Dynamo 实现了一大部分 Python,但它是否实现了更复杂的部分,例如协程或异步?它是否实现了整个 Python 标准库?NumPy 也有一个 Python API。那么 torch.compile 也理解 NumPy 吗?还有 Django?5

Python 生态系统庞大,其中很大一部分是用 C++ 或 Rust 等性能更高的语言编写的,它只是公开了 Python 绑定。Dynamo 无法跟踪通过 C++ 实现的 Python 对象。当跟踪器遇到它无法理解的操作时,它能做什么?

机器学习跟踪器处理此问题的常用方法是通知用户它们遇到的操作,并完全放弃跟踪。这在 PyTorch 的情况下会带来真正的可用性问题,因为它的用户习惯了它提供的灵活性。一个真实的例子是 doctr_det_predictor 模型使用 NumPy 和 cv2 库来 后处理模型结果

这是另一个访问 CPython 有趣的地方。Dynamo 不会出错,而是可以让 CPython 运行有问题的代码!为此,Dynamo 在跟踪时生成一个包含有问题的代码之前所有操作的图,以及一个包含之后所有操作的图。6 然后,在运行时,它将委托 CPython 执行第一个图,然后是有问题的代码,然后是第二个图。这种停止跟踪并生成多个图的过程称为图中断

一个小小的忏悔:我在整个引言和前几节中都说了谎。Dynamo 不会生成一个图,而是会生成多个图!从所有实际目的来看,在第二个图之后重新开始跟踪可以被视为开始跟踪一个新函数。图中断之后的新图将拥有自己的守卫、新的局部变量集等等。

为了讨论如何实现图中断,我们首先需要回顾 Dynamo 如何与 CPython 交互。使用 PEP 523,CPython 允许用户使用自己的帧评估机制。我们之前没有讨论的是,CPython 也公开了自己的帧评估供其他人使用。Dynamo 利用这一点,让快速的 CPython 解释器运行编译后的代码。对于没有图中断的函数,一个程序调用函数 2 次且使用相同参数的整个跟踪/执行过程如下所示

  1. 第一次调用函数时

    1. Dynamo 将函数跟踪为 FX 图

      1. FX 图由编译器(Inductor)编译成高效的低级代码……但这又是另一回事了

    2. 它重写函数的字节码,使其只调用编译后的函数

    3. 它将此新字节码交给 CPython 并要求它 在此处 运行它

  2. 第二次调用函数时

    1. 它将第一次调用中的守卫与新参数进行比较这里。由于它们与之前是相同的参数,因此它们通过

    2. 它要求 CPython 运行与这些守卫关联的字节码 此处

这个过程本身看起来过于复杂。为什么生成新的字节码并让 CPython 运行它,而不是简单地为编译函数创建 C++ 绑定并执行它呢?嗯,这种模式允许我们实现图中断!图中断生成的字节码具有以下结构

  1. 执行第一个图的字节码

  2. 使栈保持与 CPython 执行第一个图后相同的状态的字节码。它还回放此时可见的对局部或全局变量的任何修改

  3. 导致 Dynamo 图中断的字节码

  4. 执行第二个图的字节码

让我们看一个简单的例子

import torch

@torch.compile
def fn(a):
    b = a + 2
    print("Hi")
    return b + a

fn(torch.randn(4))

TORCH_LOGS=bytecode 运行此命令,将显示初始字节码和修改后的字节码

MODIFIED BYTECODE fn script.py line 3
 0 LOAD_GLOBAL              1 (__compiled_fn_0)
 2 LOAD_FAST                0 (a)
 4 CALL_FUNCTION            1
 6 STORE_FAST               3 (graph_out_0)
 8 LOAD_GLOBAL              0 (print)
10 LOAD_CONST               2 ('Hi')
12 LOAD_FAST                3 (graph_out_0)
14 LOAD_CONST               3 (0)
16 BINARY_SUBSCR
18 STORE_FAST               1 (b)

20 CALL_FUNCTION            1
22 LOAD_GLOBAL              2 (__resume_at_14_1)
24 ROT_TWO
26 LOAD_FAST                0 (a)
28 LOAD_FAST                1 (b)
30 CALL_FUNCTION            3
32 RETURN_VALUE

MODIFIED BYTECODE resume_in_fn script.py line 6
 0 LOAD_GLOBAL              1 (__compiled_fn_2)
 2 LOAD_FAST                2 (b)
 4 LOAD_FAST                1 (a)
 6 CALL_FUNCTION            2
 8 UNPACK_SEQUENCE          1
10 RETURN_VALUE

我们可以看到修改后的字节码被分成两个函数,fn 是原始函数,另一个函数名为 resume_in_fn。第二个函数是 Dynamo 为实现从图中断开始的程序执行而创建的函数。这通常被称为延续函数。这个延续函数只是用正确的参数调用第二个编译后的函数。初始函数的代码被重写,实现了我们之前描述的策略

  • L0-4. 调用编译后的函数 (a + 2)。

  • L6. 将其结果存储在名为 graph_out_0 的局部变量中。graph_out_0 是一个元组

  • L8-18. 将堆栈保留在图中断时的状态

  • L20. 执行导致图中断的代码

  • L22-32. 调用编译后的续流函数 (a + b)

Dynamo 中栈的代码生成委托给 VariableTracker 子类。Dynamo 中的每个 VariableTracker 对象都有一个 reconstruct 方法,用于生成必要的字节码以在栈上创建它所代表的 Python 对象。

调试提示。图中断会影响性能,因此最好避免它们。使用 TORCH_LOGS=graph_breaks 运行程序是查找程序命中多少图中断的好方法。它返回的信息是 VariableTracker 对象的形式,因此上面的调试提示有时也有助于找出导致图中断的原因。

结论#

Dynamo 是一款复杂的软件。一旦你开始实现一个 CPython 解释器,你就知道你将经历一段旅程。尽管如此,我们希望这篇文章能帮助揭开它的一些神秘面纱。

Dynamo(大部分)是用 Python 实现的。我们留下了许多指向我们讨论过的代码片段的链接。我们希望阅读这些代码片段并搜索调用它们的地方,或者在它们上面设置断点并查看调用栈,有助于理解其余的代码库。

当然,学习软件工作原理的最佳方式是扩展它。在这种情况下,最好的方法是查看 github 上的开放 Dynamo 问题。其中许多问题只需要对代码进行非常小的更改,一旦你找到需要进行这些更改的地方。

脚注#

以下是本文档中提及概念的附加详细信息和参考。


1

在文献中,这被称为有向无环图(DAG)。

2

所有这些绑定代码都在 torch/csrc/dynamo/eval_frame.c 中。

3

在 CPython 术语中,所有这些对象的集合称为

4

还有 SymBoolSymFloat 类。后者在撰写本文时使用不多。

5

有趣的是,它确实理解 NumPy 代码!请查看 这篇博文文档。现在,这之所以可能,是因为我们使用 PyTorch 重新实现了 NumPy。不过,在 PyTorch 中实现 Django 可就难了……

6

假设只有一段有问题的代码。如果还有更多,Dynamo 可以将代码分成它需要的任意多个图。