• 文档 >
  • 使用 scanscan_layers 的指南
快捷方式

scan 和 scan_layers 使用指南

这是 PyTorch/XLA 中使用 scanscan_layers 的指南。

何时使用此功能

如果您有一个包含许多同构(形状相同、逻辑相同)层(例如 LLM)的模型,则应考虑使用 ``scan_layers` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py>`_。这些模型的编译速度可能很慢。scan_layers 是同构层(如多个解码器层)的 for 循环的直接替代品。scan_layers 会跟踪第一层,并为所有后续层重用编译结果,从而显著减少模型编译时间。

另一方面,``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 是一个更低级的高阶算子,其模型基于 ``jax.lax.scan` <https://jax.net.cn/en/latest/_autosummary/jax.lax.scan.html>`_。它的主要目的是在后台帮助实现 scan_layers。但是,如果您想编程某种循环逻辑,其中循环本身在编译器(特别是 XLA While 算子)中具有头等表示,您可能会发现它很有用。

scan_layers 示例

通常,Transformer 模型会通过一系列同构解码器层传递输入嵌入,如下所示

def run_decoder_layers(self, hidden_states):
  for decoder_layer in self.layers:
    hidden_states = decoder_layer(hidden_states)
  return hidden_states

当此函数被降低到 HLO 图时,for 循环会被展开成一个扁平的操作列表,导致编译时间过长。为了缩短编译时间,您可以将 for 循环替换为对 scan_layers 的调用,如 ``decoder_with_scan.py` </examples/scan/decoder_with_scan.py>`_ 所示

def run_decoder_layers(self, hidden_states):
  from torch_xla.experimental.scan_layers import scan_layers
  return scan_layers(self.layers, hidden_states)

您可以通过从 pytorch/xla 源代码签出目录的根目录运行以下命令来训练此解码器模型。

python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan

scan 示例

``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 接受一个组合函数,并将其应用于张量的领先维度,同时携带状态

def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  ...

您可以使用它来高效地循环遍历张量的领先维度。如果 xs 是一个单独的张量,则此函数大致等于以下 Python 代码

def scan(fn, init, xs):
  ys = []
  carry = init
  for i in len(range(xs.size(0))):
    carry, y = fn(carry, xs[i])
    ys.append(y)
  return carry, torch.stack(ys, dim=0)

在后台,scan 通过将循环降低到 XLA While 操作来实现,效率更高。这确保了 XLA 只编译循环的一个迭代。

``scan_examples.py` </examples/scan/scan_examples.py>`_ 包含一些展示如何使用 scan 的示例代码。在该文件中,scan_example_cumsum 使用 scan 实现累积求和。scan_example_pytree 展示了如何将 PyTrees 传递给 scan

您可以使用以下命令运行示例

python3 examples/scan/scan_examples.py

输出应类似于以下内容

Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
        [3.],
        [6.]], device='xla:0')


Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
        [1.5000],
        [2.0000],
        [2.5000],
        [3.0000]], device='xla:0')

限制

AOTAutograd 兼容性要求

scanscan_layers 的函数/模块必须是 AOTAutograd 可跟踪的。特别是,截至 PyTorch/XLA 2.6,scanscan_layers 无法跟踪带有自定义 Pallas 内核的函数。这意味着如果您的解码器使用了,例如闪存注意力,那么它与 scan 不兼容。我们正在努力在 nightly 版本和下一个版本中支持此重要用例(https://github.com/pytorch/xla/issues/8633)。

AOTAutograd 开销

因为 scan 使用 AOTAutograd 来确定每个迭代中输入函数/模块的后向传递,所以与 for 循环实现相比,它很容易受到跟踪瓶颈的影响。事实上,截至 PyTorch/XLA 2.6,train_decoder_only_base.py 示例在 scan 下比使用 for 循环运行得更慢,这是由于此开销所致。我们正在努力提高跟踪速度(https://github.com/pytorch/xla/issues/8632)。当模型非常大或层数很多时,这不成问题,而这正是您想要使用 scan 的情况。

编译时间实验

为了演示编译时间节省,我们将使用 for 循环与 scan_layers 在单个 TPU 芯片上训练一个具有许多层的简单解码器。

  • 运行 for 循环实现

 python3 examples/train_decoder_only_base.py \
    --hidden-size 256 \
    --num-layers 50 \
    --num-attention-heads 4 \
    --num-key-value-heads 2 \
    --intermediate-size 2048 \
    --num-steps 5 \
    --print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 02m57s694ms418.595us
  ValueRate: 02s112ms586.097us / second
  Rate: 0.054285 / second
  Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us;
  99%=01m03s028ms571.841us
  • 运行 scan_layers 实现

 python3 examples/train_decoder_only_base.py \
    scan.decoder_with_scan.DecoderWithScan \
    --hidden-size 256 \
    --num-layers 50 \
    --num-attention-heads 4 \
    --num-key-value-heads 2 \
    --intermediate-size 2048 \
    --num-steps 5 \
    --print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 29s996ms941.409us
  ValueRate: 02s529ms591.388us / second
  Rate: 0.158152 / second
  Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us;
  99%=18s995ms301.667us

通过切换到 scan_layers,我们可以看到最长编译时间从 1m03s 降低到 19s

参考资料

有关 scanscan_layers 本身设计的更多信息,请参阅 https://github.com/pytorch/xla/issues/7253

有关如何使用 scanscan_layers 的详细信息,请参阅 ``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 和 ``scan_layers` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py>`_ 的函数文档注释。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源