Dynamo 深度解析#
创建时间:2024 年 4 月 2 日 | 最后更新时间:2025 年 8 月 12 日
torch.compile
中的追踪器是 TorchDynamo(或简称 Dynamo),它通常是那些令人费解的堆栈回溯的“罪魁祸首”。然而,我们不能一概而论地将错误归咎于 Dynamo。为了给用户提供所需的灵活性,Dynamo 肩负着理解任何 Python 程序的艰巨任务。特别是,Dynamo 需要在内部实现相当一部分 Python 语言的功能!
在本文中,我们将从头开始介绍 Dynamo 的内部设计。我们将讨论它提供的功能以及它的实现方式。阅读本文后,您将更深入地了解当您 torch.compiled
一个 PyTorch 程序并遇到编译错误时,或者编译成功但速度提升不如预期时,到底发生了什么。
Dynamo 入门指南#
在深入探讨所有实现细节之前,让我们先讨论一下 Dynamo 的作用。
Dynamo 是一个追踪器。这意味着,给定一个函数及其输入,它会执行该函数并将指令的线性序列(不包含控制流)记录到一个图中。例如,考虑以下程序
import torch
@torch.compile
def mse(x, y):
z = (x - y) ** 2
return z.sum()
x = torch.randn(200)
y = torch.randn(200)
mse(x, y)
如果我们将此程序保存到文件 example.py
并运行
TORCH_LOGS=graph_code python example.py
我们将看到 Dynamo 追踪的输出
def forward(l_x_: torch.Tensor, l_y_: torch.Tensor):
# File: example.py:5, code: z = (x - y) ** 2
sub = l_x_ - l_y_
z = sub ** 2
# File: example.py:6, code: return z.sum()
sum_1 = z.sum()
return (sum_1,)
我们称之为给定输入的函数的图(或追踪)。这通过 FX graph 表示。我们将 FX graph 简单地视为一个存储函数调用列表的容器。
我们首先注意到的是,该图是 PyTorch 操作的线性序列。1 Dynamo 记录所有 PyTorch 操作并将它们按顺序存储。例如,它将 z = (x - y) ** 2
分解为其两个组成操作 sub = l_x_ - l_y_
和 z = sub ** 2
。
当我们说追踪是线性的时,我们指的是没有分支或任何控制流。要理解这一点,请考虑
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
当使用 TORCH_LOGS=graph_code
执行时,它会返回
def forward(l_x_: torch.Tensor):
# File: example.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: example.py:7, code: return (n + 1) * y
mul = 3 * y
return (mul,)
我们看到 Dynamo 完全从追踪中移除了 if
语句,只记录了使用输入执行的操作。
因此,应该清楚的是,函数的追踪取决于输入。特别地,这意味着当我们在 @torch.compile
中编写代码时,追踪并不会生成,而是在使用实际参数执行函数 fn(x, 2)
时生成。
另一个值得注意的有趣之处是 Dynamo 移除了函数的第二个参数。相反,它将其视为常量,并在图中记录操作 n + 1
的结果。这是 Dynamo 的另一个特性:Dynamo 会将任何非张量值视为常量…除了整数。现在让我们看看整数是如何特殊的。
Dynamo 的最后一个定义属性是它知道如何处理动态形状。符号形状是指 Dynamo 追踪形状的能力,更广泛地说,是追踪整数而不是将它们保留为常量。这使得可以避免重新编译,并在生产环境中部署通用的、适用于任何大小的模型。动态形状出现的主要例子是批量大小(batch size),我们可以用固定的批量大小训练模型,然后在推理时使用任意批量大小,或者是在处理文本或音频时遇到的可变序列长度。
我们可以通过多次执行上面的示例来看到这一点
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
fn(x, 3)
fn(x, -2)
在这种情况下,TORCH_LOGS=graph_code
生成了另外两个图
# Graph for n==2 omitted
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:7, code: return (n + 1) * y
add = l_n_ + 1
mul = add * y
return (mul,)
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:9, code: return y / n
truediv = y / l_n_
return (truediv,)
Dynamo 检测到一个整数在其第一次调用后改变了值,并开始追踪它。我们看到这些图是通用的,并通过 SymInt
类型的一个对象来符号化地追踪变量 n
。
如果在这些调用之后调用 fn(x, 4)
,Dynamo 不会重新编译,而是会重用已追踪的图。
总结:1. Dynamo 是一个 Python 追踪器 2. 给定一些输入,它返回一个包含执行的 PyTorch 函数的 FX 图 3. 如果它检测到整数在调用之间发生变化,它也可以追踪整数 4. 它会专门化任何非张量或标量值
当然,Dynamo 还做了很多其他事情,比如确定何时需要重新追踪、重写函数的字节码、实现图中断…为了保持介绍简短,我们将在后续文章中逐步讨论所有这些内容。
PEP 523:为 CPython 添加帧评估 API#
现在,想象一下我们的任务是实现 Dynamo。我们从哪里开始呢?恰好,PEP 523 在 Python 3.6 中发布了。该 PEP 旨在 让第三方能够创建 Python 的 JIT 编译器。让我们看看它是如何实现的。
关于 CPython 的说明:CPython 在内部实现为一个栈式虚拟机。Python 程序被编译成字节码,然后由解释器执行。要了解更多关于这些字节码的信息,请参阅标准库中的 dis 模块。另请参阅 开发者文档 以了解 CPython 解释器的介绍。我们假设读者熟悉栈式虚拟机这个概念。
PEP 523 公开了一个 API,用户可以通过该 API 添加自定义的每个函数的解释器。然后,CPython 将使用该解释器而不是其自身的解释器来执行函数。为了能够执行函数,在进入时,CPython 会向自定义解释器提供诸如以下内容:- 函数的字节码 - 函数参数的值(即局部变量)及其名称 - 全局变量的值及其名称 - 内置函数,如 abs
或 print
总之,CPython 向用户的解释器提供了执行函数所需的所有信息。3
有了这个 API,我们可以通过实现一个解释器来创建一个追踪器,该解释器运行代码并将所有在执行过程中发生的 PyTorch 操作记录到一个图中。这正是 Dynamo 的做法。
Dynamo 使用这个 CPython API 来解析所有这些对象,并将它们打包成一个 Python 结构。完成之后…它从 C 回到 Python。除了这段与 CPython 通信的代码之外,Dynamo 完全是用 Python 实现的。
应该清楚的是,装饰器 @torch.compile
的工作是安装必要的脚手架,以便在调用函数时将字节码、参数、全局变量等传递给 Dynamo。同样,@torch.compile
实际上并没有编译任何东西。
用 Python 实现 CPython#
所以,我们回到了 Python 世界。我们拥有函数的字节码以及执行它所需的所有上下文。特别是,我们来到了 _convert_frame_assert。这是装饰器 torch.compile
返回的函数!我们从 _dynamo.optimize 到达此函数。装饰器 torch.compile
只是 _dynamo.optimize
的一个方便的 API。
在开始实现 Python 解释器之前,我们想定义一个中间表示 (IR)。特别是,我们想将所有局部变量和全局变量包装在我们自己的内部类中。这使我们能够更好地跟踪这些对象,并将可以以相同方式处理的对象组合起来,以便 Dynamo 识别。
内部类结构中的父类是 VariableTracker
,它代表 Dynamo 理解的不同对象。例如,ListVariable
代表一个 list
对象,并在内部维护一个 VariableTrackers 列表。另一个 VariableTracker
的例子是 ConstantVariable。ConstantVariable 包装了所有 Dynamo 认为常量的对象。我们还有专门的子类来处理需要特殊关注的对象,例如 TensorVariable。所有这些内部类都定义在 torch/_dynamo/variables 文件夹中。
Python 对象被包装到其对应的 VariableTracker
类中,在 VariableBuilder._wrap 中。这个函数只是一个非常长的 elif
链,它试图递归地将 Python 输入模式匹配到适当的 VariableTracker
类型。
调试技巧。当 Dynamo 产生非预期结果时,有时是由于构建器引起的。如果构建器的逻辑错误,有时 Dynamo 可能会将变量包装到错误的 VariableTracker
类型中,这可能会在后续导致问题。在遇到 Dynamo 错误时,查看错误中出现的 VariableTracker
类型以及抛出异常的 VariableTracker
方法非常有帮助。特别是,有时我们会发现一个对象被追踪为 UserDefinedObjectVariable
(这是 Dynamo 的通用类),而它应该被追踪为更具体的类型。在这些情况下,VariableBuilder
的逻辑通常是罪魁祸首。
调试技巧。当使用 TORCH_LOGS=dynamo
运行程序时,打印出的一个伪像是这样的行:
TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(<built-in method any>), TensorVariable()]
这是原始程序的字节码以及当时堆的状态。这对于查找对象未被正确追踪到 VariableTracker
中的位置非常有用。
好了,我们有了一个追踪器的 IR,现在我们只需要重新实现 CPython 的栈式虚拟机。这由 InstructorTranslatorBase 在 symbolic_convert.py 中实现。
InstructionTranslatorBase
拥有大约 200 个方法,实现了几乎所有的 Python 字节码。例如,我们可以看到 BUILD_LIST
的实现
def BUILD_LIST(self, inst):
items = self.popn(inst.argval)
self.push(ListVariable(items, mutation_type=ValueMutationNew()))
这是由 l = [2, 3, 4]
等构造生成的字节码。在这种情况下,由于有三个元素,生成的字节码是 BUILD_LIST 3
。这意味着我们将堆栈顶部的 3
个元素弹出,并将由这三个元素组成的新列表对象推到堆栈顶部。
生成输出图#
通过一种符号化执行 Python 代码的方法,我们可以提取在给定输入下对程序进行符号化执行期间发生的 PyTorch 操作。这在 Dynamo 中通过 OutputGraph 对象实现。OutputGraph
对象绑定到 InstructionTranslator 对象,并跟踪创建 Dynamo 返回的 FX 图所需的所有数据。
FX 图的所有输入和中间元素都是 fx.Node
。在 Dynamo 中,fx.Node
被包装在 fx.Proxy
中。fx.Proxy
用于构建 FX 图。特别是,它们会将对它们执行的每个 PyTorch 操作记录到图中。您可以创建一个新的操作并通过调用 create_proxy 来添加到图中。然后,我们可以通过函数 wrap_fx_proxy 将其添加到图中。
一个图存储对张量的操作…以及对符号整数的操作。我们稍后将讨论符号整数,但首先我们将讨论 Dynamo 如何解决一个相当重要的正确性问题。
使 Dynamo 正确:Guard#
此时,我们有了一种方法可以完全忽略控制流来追踪程序。为此,我们重新实现了所有 CPython…如果这听起来有点杀鸡焉用牛刀,那是因为它确实是。torch.jit.trace 已经实现了这一点,而无需所有这些机制,那么 Dynamo 有何优势?
正如其文档中所警告的,torch.jit.trace
的问题在于,只有当追踪的程序不是数据依赖的时,它才能正常工作。换句话说,如果程序本身是线性的,它就能工作。这意味着编写程序时不能使用 if-else、for-while 循环、异常。更重要的是,我们使用的任何库都不能使用任何控制流!总而言之,在一个像 Python 这样动态的语言中不使用控制流,实际上是一个巨大的限制。
JAX 通过始终重新追踪并在重新追踪后缓存图来解决这个问题。另一方面,Dynamo 使用 Guard 来避免每次都重新追踪整个程序。
Guard 是一个假设(关于输入的布尔表达式),为了使一个帧专门化为一组示例输入而做的。重用图只有在这些假设在新输入上成立时才有效。
例如,函数中的任何常量输入(如字符串)都会安装一个 Guard,声明该输入应为 str
类型,并且等于我们传入的字符串。运行
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
并使用 TORCH_LOGS=guards
打印(除其他 Guard 外)
___check_type_id(L['b'], 94334122025024)
L['b'] == 'Hello'
这可以解读为“局部变量 b
应该具有特定的类型(在这种情况下为 str
,由常量 9433...
表示),并且其值应为 'Hello'
”。如果我们然后再次执行该函数并传入一个不同的参数
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
fn(torch.arange(10), "Hi")
通过运行 TORCH_LOGS=recompiles
,我们可以看到失败的 Guard
Recompiling function fn in script.py:3
triggered by the following guard failure(s):
- L['b'] == 'Hello'
Guard 在函数输入被包装在构建器中以及程序执行期间累积。我们将在下一节中展示更多 Guard 的示例,但首先让我们讨论 Source。
Source 跟踪如何从进入当前帧时存在的原始局部变量或全局变量中重建变量。特别是,它跟踪原始局部变量和全局变量以及它们包含的任何对象。在
def foo(x: Tensor, y: List[Tensor]):
a = x * y[0]
return a * x
x
和 y
的 Source 是 LocalSource,而 y[0]
的 Source 是 GetItemSource,它在内部存储了一个 LocalSource
。另一方面,a
将没有 Source,因为它是仅存在于 fx 图中的中间变量。
所有这些都在 torch/_dynamo/source.py 中定义。我们可以在以下示例中看到由 GetItemSource
生成的 Guard
import torch
@torch.compile
def fn(x, l):
return x * len(l[0])
fn(torch.randn(8), ["Hi", "Hello"])
生成以下 Guard
___check_type_id(L['l'], 94439025877664)
len(L['l']) == 2
___check_type_id(L['l'][0], 94439025840192)
L['l'][0] == 'Hi'
___check_type_id(L['l'][1], 94439025840192)
L['l'][1] == 'Hello'
在这里,我们看到由 GetItemSource
([0]
和 [1]
)生成的代码,它包装了一个 LocalSource
(L['l']
)。
至此,我们有了 Source 和 Guard,我们就能够实现一个缓存系统,以避免不必要的重新编译,而无需每次都重新追踪。我们将在后续文章中更详细地讨论这个缓存系统。
细心的读者会注意到,这还没有解释为什么我们需要对 Python 解释器进行如此精细地控制,以至于需要重新实现它。我们展示的 Guard 示例依赖于输入对象,因此我们仍然可以在执行函数之前计算它们。换句话说,我们可以将这个 Guard 系统实现在 torch.jit.trace
的顶层,并以更少的精力获得相同的功能…进入符号形状。
符号形状#
我们在介绍中讨论的另一个观点是,Dynamo 知道如何追踪整数。为了实现这一点,我们使用一个符号类 torch.SymInt,它像一个 int
一样工作,但它会将所有对其执行的操作记录到输出的 FX 图中。4 在介绍中介绍符号整数追踪时,我们已经见过这个类。
现在让我们讨论定义 Dynamo 中符号形状追踪的三种属性以及如何实现它们。
默认静态#
Dynamo 假定每个整数,无论是输入还是张量的形状,默认都是静态的。换句话说,在函数第一次执行时不会追踪任何整数。然后,只有当它检测到在执行过程中整数或形状的值发生了变化时,它才会追踪它并生成一个关于该变量的通用图。
我们在介绍中已经使用整数看到了这种行为。现在让我们来看一个使用张量形状的示例。
import torch
@torch.compile
def fn(a, b):
return a.shape[0] * a * b
fn(torch.randn(4, 3), torch.randn(4, 3))
fn(torch.randn(8, 3), torch.randn(8, 3))
使用 TORCH_LOGS=graph_code
运行此程序,我们看到这两个调用被追踪为
def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor):
mul = 4 * l_a_
mul_1 = mul * l_b_
return (mul_1,)
def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor):
size = l_a_.size()
getitem = size[0]
mul = getitem * l_a_
mul_1 = mul * l_b_
return (mul_1,)
在第一个图中,形状被追踪为常量,但一旦它发生变化,它就会使用 SymInt
符号化地追踪它。通常,查看中间值的形状的更简单方法是使用 TORCH_LOGS=graph_sizes
运行程序
TRACED GRAPH TENSOR SIZES
===== __compiled_fn_1 =====
l_a_: (s0, 3)
l_a_ (concrete): (8, 3)
l_b_: (s0, 3)
l_b_ (concrete): (8, 3)
mul: (s0, 3)
mul (concrete): (8, 3)
mul_1: (s0, 3)
mul_1 (concrete): (8, 3)
在那里,我们可以看到两个张量参数的第一个维度是动态的,因为它由 s0
变量表示。
我们可以通过运行 TORCH_LOGS=guards
来找到 Dynamo 实现这一点的方法
# Guards first call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
# Guards second call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
L['b'].size()[0] == L['a'].size()[0]
2 <= L['a'].size()[0]
我们看到在第一次调用时,Guard 检查张量是否具有固定的尺寸和步幅。这些 Guard 在第二次执行时会失败,所以它会重新追踪。由于这是一个 int
Guard 失败了,所以在第二次迭代中,它会将这个 int
符号化地追踪,并为这个更通用的内核安装更通用的 Guard。
编译性能技巧。如果您知道某个维度的大小会变化,可以在调用 torch.compile
之前通过调用 torch._dynamo.mark_dynamic 将其标记为动态。这将避免第一次使用静态形状进行编译。还有其他有用的实用函数,如 maybe_mark_dynamic
或 mark_static
。您也可以通过调用 torch.compile(dynamic=True)
来追踪所有整数和形状。这主要用于调试目的。
0 和 1 总是被专门化#
无论我们是否将某个维度标记为动态,如果我们传入一个该维度为 0 或 1 的输入,Dynamo 将会将其追踪为非动态,并为其生成一个特定的图。这就是为什么在上面的示例中我们会发现形式为 2 <= L['a'].size()[0]
的 Guard。
这个选择有几个原因。其中有两个尤其重要:- 张量为空当且仅当其任何维度为零- 张量仅在步幅之一为一时才能是连续的
这个策略决定不适用于普通 Python 整数;如果我们认为一个 Python 整数应该被动态编译,我们默认不会专门化它们;相反,它是否被专门化取决于它的用法。
鸭子类型形状(Duck Shaping)#
Dynamo 执行所谓的“鸭子类型形状”。如果两个动态整数在追踪时具有相同的值,我们将假设它们相等并为此建立 Guard。有效地,这意味着我们不再像上面的示例那样有两个符号 s0
、s1
,而是将它们统一为 s0
,并拥有 Guard L['b'].size()[0] == L['a'].size()[0]
。这使得在编译器内部进行融合成为可能,同时能够生成足够通用的内核。
对符号整数的 Guard#
我们现在在高层面上理解了符号形状是如何实现的以及它们的属性。那么,为什么符号形状迫使我们通过控制 CPython 解释器这种棘手的路线呢?考虑以下示例
import torch
@torch.compile(dynamic=True)
def fn(a):
if a.shape[0] * 2 < 16:
return a
else:
return a + 1
fn(torch.randn(8))
这段代码有一个形式为 2*L['a'].size()[0] >= 16
的 Guard。这是一个关于函数输入的非平凡 Guard,但它是在程序执行中间注册的。更重要的是,我们直到看到依赖于 SymNodeVariable
参数的 if
语句时,才能知道这个 Guard 是必需的。这些条件对 torch.jit.trace
是不可见的,并且需要对 Python 代码进行深度分析。
调试技巧。使用 TORCH_LOGS=dynamo
运行这段代码会告诉我们这个 Guard 是在哪里添加的
eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr)
在该处设置断点并查看回溯对于理解 Guard 的来源非常有用。
使 Dynamo 完整:图中断(Graph Breaks)#
有了我们讨论过的所有工具,我们就拥有了一个追踪器,它可以追踪张量和整数上的 PyTorch 操作,并拥有一个知道何时可以重用先前追踪的图以及何时需要重新追踪的缓存系统。所有这一切都是在执行任意 Python 代码!
这里只有一个小问题。陈述“执行任意 Python 代码”可能有点太笼统了。Dynamo 实现了一大部分 Python,但它是否实现了协程或异步等更复杂的部分?它是否实现了整个 Python 标准库?NumPy 也有一个 Python API。 torch.compile
是否也理解 NumPy?以及 Django?5
Python 的生态系统非常庞大,其中很大一部分是用 C++ 或 Rust 等更高效的语言编写的,它们只是暴露了 Python 绑定。Dynamo 无法追踪用 C++ 实现的 Python 对象。当追踪器遇到它不理解的操作时,它能做什么?
机器学习追踪器处理这个问题的常规方法是告知用户它们卡住的操作,然后完全放弃追踪。在 PyTorch 的情况下,这会带来真正的可用性问题,因为它的用户习惯了它提供的灵活性。例如,doctr_det_predictor
模型使用 NumPy 和 cv2
库来后处理模型的输出。
这里是 CPython 的另一个有趣之处。Dynamo 不是抛出错误,而是可以让 CPython 运行那个有问题的代码!为了做到这一点,Dynamo 在追踪时生成一个图,包含有问题代码之前的所有操作,以及一个包含有问题代码之后的所有操作的图。6 然后,在运行时,它会将执行第一个图、有问题的代码以及第二个图的任务委托给 CPython。这种停止追踪并生成多个图的过程称为图中断。
一个小小的坦白:我在整个介绍和前几节都说了谎。Dynamo 不只生成一个图,而是生成多个图!对于所有实际目的,在第二个图之后开始重新追踪可以被认为是在开始追踪一个新函数。图中断后的新图将有自己的 Guard,自己的一组局部变量,等等。
为了讨论如何实现图中断,我们首先需要回顾一下 Dynamo 如何与 CPython 交互。使用 PEP 523,CPython 允许用户使用自己的帧评估机制。我们还没有讨论的是,CPython 还公开了自己的帧评估供他人使用。Dynamo 利用这一点让快速的 CPython 解释器运行编译后的代码。对于没有图中断的函数,程序第一次和第二次调用函数(参数相同)的整个追踪/执行过程如下:
在第一次调用函数时
Dynamo 将函数追踪成一个 FX 图
FX 图由编译器(Inductor)编译成高效的低级代码…但这又是另一个故事了
它重写函数的字节码,使其只需调用编译后的函数
它将这个新的字节码提供给 CPython,并要求它运行它在此处
在第二次调用函数时
这个过程本身看起来过于复杂。为什么还要生成新的字节码并要求 CPython 运行它,而不是直接创建一个 C++ 绑定到编译后的函数并执行它?嗯,这个模式允许我们实现图中断!由图中断生成的字节码具有以下结构:
执行第一个图的字节码
离开堆栈的字节码,就像 CPython 执行第一个图时的状态一样。它还会重放当时可见的对局部或全局变量的任何修改
导致 Dynamo 图中断的字节码
执行第二个图的字节码
让我们看一个简单的例子
import torch
@torch.compile
def fn(a):
b = a + 2
print("Hi")
return b + a
fn(torch.randn(4))
使用 TORCH_LOGS=bytecode
运行此代码会向我们显示初始字节码和修改后的字节码
MODIFIED BYTECODE fn script.py line 3
0 LOAD_GLOBAL 1 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 CALL_FUNCTION 1
6 STORE_FAST 3 (graph_out_0)
8 LOAD_GLOBAL 0 (print)
10 LOAD_CONST 2 ('Hi')
12 LOAD_FAST 3 (graph_out_0)
14 LOAD_CONST 3 (0)
16 BINARY_SUBSCR
18 STORE_FAST 1 (b)
20 CALL_FUNCTION 1
22 LOAD_GLOBAL 2 (__resume_at_14_1)
24 ROT_TWO
26 LOAD_FAST 0 (a)
28 LOAD_FAST 1 (b)
30 CALL_FUNCTION 3
32 RETURN_VALUE
MODIFIED BYTECODE resume_in_fn script.py line 6
0 LOAD_GLOBAL 1 (__compiled_fn_2)
2 LOAD_FAST 2 (b)
4 LOAD_FAST 1 (a)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 1
10 RETURN_VALUE
我们可以看到修改后的字节码被分成两个函数:fn
(原始函数)和一个名为 resume_in_fn
的函数。第二个函数是由 Dynamo 创建的,用于实现从图中断开始的程序执行。这通常被称为续延函数 (continuation function)。此续延函数只需使用正确的参数调用第二个编译后的函数。初始函数的代码被重写,实现了我们之前描述的策略:
L0-4. 调用编译后的函数(
a + 2
)。L6. 将其结果存储在名为
graph_out_0
的局部变量中。graph_out_0
是一个元组L8-18. 离开堆栈,使其处于图中断时的状态
L20. 执行导致图中断的代码
L22-32. 调用编译后的续延函数(
a + b
)
Dynamo 中堆栈的代码生成委托给了 VariableTracker
子类。Dynamo 中的每个 VariableTracker
对象都有一个 reconstruct 方法,该方法生成必要的字节码以在堆栈上创建它所代表的 Python 对象。
调试技巧。图中断会影响性能,因此最好避免它们。使用 TORCH_LOGS=graph_breaks
运行程序是查找程序触发了多少图中断的一个好方法。它返回的信息是关于 VariableTracker
对象,所以上面的调试技巧有时也有助于弄清楚是什么原因导致了该图中断。
结论#
Dynamo 是一个复杂的软件。一旦你决定实现一个 CPython 解释器,你就知道这是一段艰难的旅程。话虽如此,我们希望这篇文章能帮助您对其进行一些解释。
Dynamo (主要)是用 Python 实现的。我们留下了许多指向我们讨论过的代码片段的链接。我们希望阅读这些代码片段,然后搜索调用它们的地方,或者在它们上面设置断点并查看调用堆栈,有助于理解代码库的其余部分。
当然,学习软件工作原理的最佳方法是扩展它。在这种情况下,最好的方法是查看 GitHub 上的开放 Dynamo 问题。其中许多只需要对代码进行非常小的更改,一旦您找到需要进行更改的地方。