使用 scan
和 scan_layers
优化重复层¶
本指南介绍了如何在 PyTorch/XLA 中使用 scan
和 scan_layers
。
何时应使用此方法¶
如果您有一个模型包含许多同质(形状相同、逻辑相同)的层,例如 LLMs,可以考虑使用 scan_layers
。这些模型的编译可能会很慢。scan_layers
可以直接替换掉同质层(例如,一系列解码器层)上的 for 循环。scan_layers
会跟踪第一层并为所有后续层重用编译结果,从而显著减少模型的编译时间。
另一方面,scan
是一个更底层的、高阶算子,它模仿了 jax.lax.scan
。它的主要目的是在底层实现 scan_layers
。然而,您可能会发现它对于编程循环逻辑很有用,在这种逻辑中,循环本身在编译器中具有第一类表示(特别是,XLA 的 while
算子)。
scan_layers
示例¶
通常,Transformer 模型将输入嵌入通过一系列同质解码器层。
def run_decoder_layers(self, hidden_states):
for decoder_layer in self.layers:
hidden_states = decoder_layer(hidden_states)
return hidden_states
当此函数被降低到 HLO 图时,for 循环将被展开成一个扁平的操作列表,导致编译时间长。为了减少编译时间,请将 for 循环替换为 scan_layers
,如 decoder_with_scan.py
所示。
def run_decoder_layers(self, hidden_states):
from torch_xla.experimental.scan_layers import scan_layers
return scan_layers(self.layers, hidden_states)
您可以从 pytorch/xla
源代码检出的根目录运行以下命令来训练此解码器模型。
python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
scan
示例¶
scan
接受一个组合函数,并在处理张量的首个维度时应用该函数,同时携带状态。
def scan(
fn: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
) -> tuple[Carry, Y]:
...
使用它来高效地遍历张量的首个维度。如果 xs
是一个单一的张量,这个函数大致等于以下 Python 代码:
def scan(fn, init, xs):
ys = []
carry = init
for i in len(range(xs.size(0))):
carry, y = fn(carry, xs[i])
ys.append(y)
return carry, torch.stack(ys, dim=0)
在底层,scan
通过将循环降低为一个 XLA while
操作来实现高效。这确保了 XLA 只编译循环的一个迭代。
scan_examples.py
包含一些展示如何使用 scan
的示例代码。scan_example_cumsum
在该文件中使用 scan
来实现累积和。scan_example_pytree
展示了如何将 PyTrees 传递给 scan
。
您可以使用以下命令运行示例:
python3 examples/scan/scan_examples.py
输出应类似如下:
Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
[3.],
[6.]], device='xla:0')
Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
[1.5000],
[2.0000],
[2.5000],
[3.0000]], device='xla:0')
使用 scan 缓存¶
由于 scan
使用 AOTAutograd 来确定输入函数/模块在每次迭代中的反向传播,与 for 循环实现相比,它很容易受到跟踪的限制。随着 scan 缓存的实现,用户可以指明提供的函数(或使用 scan_layers
时的层)是纯函数,以激活缓存机制。缓存显著减少了跟踪开销,尤其是在迭代次数很多的情况下。以下命令训练了一个示例解码器。在没有缓存的情况下,它需要 9m35s
,而使用缓存则需要 4m59s
。
time python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan --hidden-size 128 --num-layers 2 --num-attention-heads 8 --num-key-value-heads 4 --intermediate-size 512 --num-steps 500 --print-metrics --is-decoder-layer-pure
要启用缓存,只需将 is_fn_pure
(或在使用 scan_layers
时使用 is_layer_pure
)设置为 True
。例如:
final_carry, ys = scan(fn, init_scan, xs_scan, is_fn_pure=is_fn_pure)
scan_layers(layers, input_data, is_layer_pure=True)
限制¶
AOTAutograd 兼容性要求¶
传递给 scan
和 scan_layers
的函数/模块必须是 AOTAutograd 可跟踪的。特别地,截至 PyTorch/XLA 2.6,scan
和 scan_layers
无法跟踪带有自定义 Pallas 内核的函数。这意味着如果您的解码器使用了,例如 flash attention,那么它与 scan
不兼容。我们正在努力 支持这一重要的用例。
编译时间实验¶
为了演示编译时间的节省,我们将使用 for 循环与 scan_layers
在单个 TPU 芯片上训练一个具有许多层的简单解码器。
运行 for 循环实现
❯ python3 examples/train_decoder_only_base.py \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics
...
Metric: CompileTime
TotalSamples: 3
Accumulator: 02m57s694ms418.595us
ValueRate: 02s112ms586.097us / second
Rate: 0.054285 / second
Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us;
99%=01m03s028ms571.841us
运行
scan_layers
实现
❯ python3 examples/train_decoder_only_base.py \
scan.decoder_with_scan.DecoderWithScan \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics
...
Metric: CompileTime
TotalSamples: 3
Accumulator: 29s996ms941.409us
ValueRate: 02s529ms591.388us / second
Rate: 0.158152 / second
Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us;
99%=18s995ms301.667us
通过切换到 scan_layers
,最大编译时间从 1m03s
减少到 19s
。
参考资料¶
有关 scan
和 scan_layers
本身的设计,请参阅 https://github.com/pytorch/xla/issues/7253。
有关如何使用 scan
和 scan_layers
的详细信息,请参阅 scan
和 scan_layers
的函数文档注释。