使用 @assume_pure
加速追踪¶
本文档介绍如何使用 torch_xla.experimental.assume_pure
来消除惰性张量追踪的开销。有关惰性张量追踪(操作记录)工作原理的入门介绍,请参阅这篇博文。
背景和动机¶
PyTorch/XLA 的惰性张量追踪通过记录 PyTorch 操作时的操作图(惰性张量 IR)来确保正确执行。对于复杂的模型,此追踪开销可能会超过图的执行时间,从而导致性能瓶颈。在训练模型时,模型的层在每个训练步骤上都必须重新追踪。这是因为不能保证这些层在不同的训练步骤中执行相同的操作。例如,一个层的 forward()
函数可能会调用 math.random()
,并根据伪随机数决定执行哪段代码。
重新追踪会引入不必要的开销。在许多情况下,当给定相同的输入张量形状时,模型中的层会执行完全相同的操作。换句话说,给定相同的输入,函数会返回相同的输出。通常,这些层也不会执行副作用,例如将张量保存到文件或将其添加到全局列表中。此类函数称为“纯函数”。
任何用 @assume_pure
装饰的 PyTorch/XLA 函数,对于每种唯一的输入张量形状和数据类型组合,都只会追踪一次。PyTorch/XLA 会缓存追踪到的计算,而不是重复追踪相同的操作。
如何使用 @assume_pure
¶
将 @assume_pure
用于函数¶
如果你知道你的函数是纯函数,请用 @assume_pure
装饰你的函数。
import torch
import torch_xla
from torch_xla.experimental.assume_pure import assume_pure
@assume_pure
def do_some_math(
# You can pass any number of XLA tensors.
a: torch.Tensor,
b: torch.Tensor,
# Non-tensor arguments are also supported, and passing different values will
# trigger re-tracing and caching more computations.
c: int,
):
# Evaluate some pure expressions.
return a @ b + c
# Simulate a training loop.
# Even if we run this function ten times, it will only be traced once.
for i in range(10):
v = do_some_math(
torch.tensor([1.0], device='xla'),
torch.tensor([2.0], device='xla'),
c=42,
)
print(v)
将 @assume_pure
用于 nn.Module
¶
如果你有一个纯 nn.Module
,即其 forward
行为仅取决于输入参数和模型参数,我们可以使用 torch.func.functional_call
将该模块转换为纯函数,然后将其传递给 assume_pure
。
import torch
import torch.nn as nn
from torch.func import functional_call
from torch_xla.experimental.assume_pure import assume_pure
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
# Create module and move to XLA device
module = MyModule()
module = module.to('xla')
# Convert module's forward pass into a pure function
pure_forward = lambda params, buffers, x: functional_call(module, (params, buffers), (x,))
# Wrap the pure function with @assume_pure
cached_forward = assume_pure(pure_forward)
# Simulate a training loop
# Even if we run the model ten times, its forward function will only be traced once.
params = dict(module.named_parameters())
buffers = dict(module.named_buffers())
for i in range(10):
x = torch.randn(5, 10, device='xla')
y = cached_forward(params, buffers, x)
print(y)
基准测试¶
单元测试包含一个基准测试,该测试追踪了一个包含 100 层解码器独占语言模型的示例。
~/pytorch/xla
❯ TESTBRIDGE_TEST_ONLY=test_trace_transformer_with_spda_attention python3 test/test_assume_pure.py --benchmark_iterations 100
[...]
No `@assume_pure` time: 140.1342 ms
`@assume_pure` time: 24.1658 ms
使用 @assume_pure
的版本速度快得多。
重要的是,@assume_pure
的运行时间不会随着模型内部复杂度的增加而增加。这是因为我们只追踪模型一次,支付固定的前期成本,然后后续运行将重用缓存的 XLA 计算。
局限性¶
当前,用 @assume_pure
包装的函数中的所有操作都必须是 PyTorch 上游操作(例如 torch.einsum
、torch.sin
等),或者是这些 PyTorch/XLA 操作:
torch_xla.experimental.assume_pure
(递归assume_pure
)torch_xla.distributed.spmd.mark_sharding
未来将支持更多 PyTorch/XLA 操作(例如 flash_attention
)。