已编译的自动梯度:捕获更大的反向传播图以用于 torch.compile#
创建于:2024 年 10 月 09 日 | 最后更新:2024 年 10 月 23 日 | 最后验证:2024 年 10 月 09 日
作者: Simon Fan
已编译的自动梯度如何与
torch.compile交互如何使用已编译的自动梯度 API
如何使用
TORCH_LOGS检查日志
PyTorch 2.4
阅读 PyTorch 2.x 入门 中的 TorchDynamo 和 AOTAutograd 部分
概述#
已编译的自动梯度是 PyTorch 2.4 中引入的一个 torch.compile 扩展,它允许捕获更大的反向传播图。
虽然 torch.compile 会捕获反向传播图,但它是部分捕获的。AOTAutograd 组件会提前捕获反向传播图,但存在一些限制。
前向传播中的图中断会导致反向传播中的图中断
反向传播钩子未被捕获
已编译的自动梯度通过直接与自动梯度引擎集成来解决这些限制,允许它在运行时捕获完整的反向传播图。具有这两种特性的模型应尝试使用已编译的自动梯度,并可能观察到更好的性能。
然而,已编译的自动梯度也引入了自己的限制
在反向传播开始时增加了缓存查找的运行时开销
由于捕获范围更大,更容易发生重新编译和图中断
注意
已编译的自动梯度正在积极开发中,尚未与所有现有的 PyTorch 功能兼容。有关特定功能的最新状态,请参阅 已编译自动梯度登陆页面。
设置#
在本教程中,我们将基于这个简单的神经网络模型来举例。它接收一个 10 维的输入向量,通过一个线性层进行处理,并输出另一个 10 维向量。
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
基本用法#
在调用 torch.compile API 之前,请确保将 torch._dynamo.config.compiled_autograd 设置为 True。
model = Model()
x = torch.randn(10)
torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
loss = model(x).sum()
loss.backward()
train(model, x)
在上面的代码中,我们创建了一个 Model 类的实例,并使用 torch.randn(10) 生成了一个随机的 10 维张量 x。我们定义了训练循环函数 train,并使用 @torch.compile 装饰它以优化其执行。当调用 train(model, x) 时:
Python 解释器调用 Dynamo,因为此调用被装饰了
@torch.compile。Dynamo 拦截 Python 字节码,模拟其执行并将操作记录到图中。
AOTDispatcher禁用钩子,并调用自动梯度引擎来计算model.linear.weight和model.linear.bias的梯度,并将操作记录到图中。使用torch.autograd.Function,AOTDispatcher 重写了train的前向和反向传播实现。Inductor 生成一个对应于 AOTDispatcher 前向和反向传播优化实现的函数。
Dynamo 设置优化后的函数,以便 Python 解释器接下来进行评估。
Python 解释器执行优化后的函数,该函数执行
loss = model(x).sum()。Python 解释器执行
loss.backward(),调用自动梯度引擎,该引擎会路由到已编译的自动梯度引擎,因为我们将torch._dynamo.config.compiled_autograd = True设置为 True。已编译的自动梯度计算
model.linear.weight和model.linear.bias的梯度,并将操作记录到图中,包括它遇到的任何钩子。在此过程中,它将记录 AOTDispatcher 之前重写的反向传播。然后,已编译的自动梯度生成一个新函数,该函数对应于loss.backward()的完全跟踪实现,并以推理模式使用torch.compile执行它。相同的步骤将递归应用于已编译的自动梯度图,但这次 AOTDispatcher 将不需要划分图。
检查已编译的自动梯度日志#
使用 TORCH_LOGS 环境变量运行脚本。
要仅打印已编译的自动梯度图,请使用
TORCH_LOGS="compiled_autograd" python example.py。要以损失性能为代价,打印带有更多张量元数据和重新编译原因的图,请使用
TORCH_LOGS="compiled_autograd_verbose" python example.py。
重新运行上面的片段,已编译的自动梯度图现在应该被记录到 stderr。某些图节点将带有以 aot0_ 为前缀的名称,这些名称对应于之前在 AOTAutograd 反向传播图 0 中预先编译的节点,例如,aot0_view_2 对应于 ID 为 0 的 AOT 反向传播图的 view_2。
在下面的图像中,红色框包含了在没有已编译自动梯度的情况下被 torch.compile 捕获的 AOT 反向传播图。
注意
这是我们将调用 torch.compile 的图,而不是优化后的图。已编译的自动梯度本质上会生成一些未优化的 Python 代码来表示整个 C++ 自动梯度执行。
使用不同的标志编译前向和反向传播#
您可以使用不同的编译器配置进行两次编译,例如,即使前向传播中有图中断,反向传播也可以是 fullgraph。
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
torch._dynamo.config.compiled_autograd = True
torch.compile(lambda: loss.backward(), fullgraph=True)()
或者,您可以使用上下文管理器,它将应用于其作用域内的所有自动梯度调用。
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
loss.backward()
已编译的自动梯度解决了 AOTAutograd 的某些限制#
前向传播中的图中断不再必然导致反向传播中的图中断。
@torch.compile(backend="aot_eager")
def fn(x):
# 1st graph
temp = x + 10
torch._dynamo.graph_break()
# 2nd graph
temp = temp + 10
torch._dynamo.graph_break()
# 3rd graph
return temp.sum()
x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)
# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()
# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)
在第一个 torch.compile 案例中,我们看到由于编译函数 fn 中的 2 次图中断,产生了 3 个反向传播图。而在第二个使用已编译自动梯度的 torch.compile 案例中,我们看到即使存在图中断,也跟踪了一个完整的反向传播图。
注意
Dynamo 在跟踪已编译自动梯度捕获的反向传播钩子时,仍有可能发生图中断。
现在可以捕获反向传播钩子了。
@torch.compile(backend="aot_eager")
def fn(x):
return x.sum()
x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
图中应该有一个 call_hook 节点,Dynamo 稍后会将其内联到以下内容:
已编译自动梯度的常见重新编译原因#
由于损失值自动梯度结构的变化
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
loss = op(x, x).sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的示例中,我们在每次迭代时调用不同的运算符,导致 loss 跟踪不同的自动梯度历史。您应该会看到一些重新编译消息:由于新的自动梯度节点导致缓存未命中。
由于张量形状发生变化
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
x = torch.randn(i, i, requires_grad=True)
loss = x.sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的示例中,x 的形状发生变化,在第一次变化后,已编译的自动梯度会将 x 标记为动态形状张量。您应该会看到重新编译消息:由于形状变化导致缓存未命中。
结论#
在本教程中,我们概述了 torch.compile 与已编译自动梯度的生态系统、已编译自动梯地的基础知识以及一些常见的重新编译原因。请继续关注 dev-discuss 上的深度探讨。