• 文档 >
  • 使用 @assume_pure 加快跟踪速度
快捷方式

使用 @assume_pure 加速追踪

本文档介绍如何使用 torch_xla.experimental.assume_pure 来消除惰性张量追踪的开销。有关惰性张量追踪(操作记录)工作原理的入门介绍,请参阅这篇博文

背景和动机

PyTorch/XLA 的惰性张量追踪通过记录 PyTorch 操作时的操作图(惰性张量 IR)来确保正确执行。对于复杂的模型,此追踪开销可能会超过图的执行时间,从而导致性能瓶颈。在训练模型时,模型的层在每个训练步骤上都必须重新追踪。这是因为不能保证这些层在不同的训练步骤中执行相同的操作。例如,一个层的 forward() 函数可能会调用 math.random(),并根据伪随机数决定执行哪段代码。

重新追踪会引入不必要的开销。在许多情况下,当给定相同的输入张量形状时,模型中的层会执行完全相同的操作。换句话说,给定相同的输入,函数会返回相同的输出。通常,这些层也不会执行副作用,例如将张量保存到文件或将其添加到全局列表中。此类函数称为“纯函数”。

任何用 @assume_pure 装饰的 PyTorch/XLA 函数,对于每种唯一的输入张量形状和数据类型组合,都只会追踪一次。PyTorch/XLA 会缓存追踪到的计算,而不是重复追踪相同的操作。

如何使用 @assume_pure

@assume_pure 用于函数

如果你知道你的函数是纯函数,请用 @assume_pure 装饰你的函数。

import torch
import torch_xla
from torch_xla.experimental.assume_pure import assume_pure

@assume_pure
def do_some_math(
    # You can pass any number of XLA tensors.
    a: torch.Tensor,
    b: torch.Tensor,

    # Non-tensor arguments are also supported, and passing different values will
    # trigger re-tracing and caching more computations.
    c: int,
):
    # Evaluate some pure expressions.
    return a @ b + c

# Simulate a training loop.
# Even if we run this function ten times, it will only be traced once.
for i in range(10):
    v = do_some_math(
        torch.tensor([1.0], device='xla'),
        torch.tensor([2.0], device='xla'),
        c=42,
    )
    print(v)

@assume_pure 用于 nn.Module

如果你有一个纯 nn.Module,即其 forward 行为仅取决于输入参数和模型参数,我们可以使用 torch.func.functional_call 将该模块转换为纯函数,然后将其传递给 assume_pure

import torch
import torch.nn as nn
from torch.func import functional_call
from torch_xla.experimental.assume_pure import assume_pure

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
    def forward(self, x):
        return self.linear(x)

# Create module and move to XLA device
module = MyModule()
module = module.to('xla')

# Convert module's forward pass into a pure function
pure_forward = lambda params, buffers, x: functional_call(module, (params, buffers), (x,))

# Wrap the pure function with @assume_pure
cached_forward = assume_pure(pure_forward)

# Simulate a training loop
# Even if we run the model ten times, its forward function will only be traced once.
params = dict(module.named_parameters())
buffers = dict(module.named_buffers())
for i in range(10):
    x = torch.randn(5, 10, device='xla')
    y = cached_forward(params, buffers, x)
    print(y)

基准测试

单元测试包含一个基准测试,该测试追踪了一个包含 100 层解码器独占语言模型的示例。

~/pytorch/xla
❯ TESTBRIDGE_TEST_ONLY=test_trace_transformer_with_spda_attention python3 test/test_assume_pure.py --benchmark_iterations 100
[...]
No `@assume_pure` time: 140.1342 ms
`@assume_pure` time: 24.1658 ms

使用 @assume_pure 的版本速度快得多。

重要的是,@assume_pure 的运行时间不会随着模型内部复杂度的增加而增加。这是因为我们只追踪模型一次,支付固定的前期成本,然后后续运行将重用缓存的 XLA 计算。

局限性

当前,用 @assume_pure 包装的函数中的所有操作都必须是 PyTorch 上游操作(例如 torch.einsumtorch.sin 等),或者是这些 PyTorch/XLA 操作:

  • torch_xla.experimental.assume_pure(递归 assume_pure

  • torch_xla.distributed.spmd.mark_sharding

未来将支持更多 PyTorch/XLA 操作(例如 flash_attention)。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源