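Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile()#
Learn about the low-level building blocks PyTorch provides for building custom transformer layers (nested tensors, scaled_dot_product_attention, torch.compile(), and FlexAttention)
Discover how these building blocks improve memory usage and performance, using MultiHeadAttention as an example
Explore advanced customizations using the same building blocks
Requires PyTorch v.2.6.0 or later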
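Over the past few years, the PyTorch team has developed various lower-level features that, when composed, can create a variety of transformer variants. These include:
Nested Tensors with the torch.jagged layout (also known as NJTs)
scaled_dot_product_attention
torch.compile()
FlexAttention
This tutorial gives a brief overview of these technologies and demonstrates how they can be composed to build flexible and performant transformer layers with an improved user experience.
Note that the torch.nn module currently provides various Transformer-related layers: TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, Transformer, and MultiheadAttention. This family of layers was originally implemented following the "Attention is All You Need" paper. The components discussed in this tutorial provide improved user experience, flexibility, and performance over the existing nn layers.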
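Is this tutorial for me?#
If you are wondering what building blocks the torch library provides for writing your own transformer layers, and what the best practices are, you are in the right place. Please keep reading!
If you are looking for an out-of-the-box implementation of a popular transformer architecture, note that many open-source libraries provide them.
If you are only interested in performant attention score modifications, please check out the FlexAttention blog, which contains a gym of masks.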
Introducing the building blocks#
First, we briefly introduce the four technologies mentioned in the introduction.
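Nested tensors generalize the shape of regular dense tensors, allowing ragged data to be represented with the same tensor API. In the context of transformers, we can think of nested tensors as a tool for representing variable sequence lengths. They eliminate the need for the error-prone practices of explicit padding and masking (for example, key_padding_mask in nn.MultiheadAttention).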
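scaled_dot_product_attention is a primitive for \(\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V\) that dispatches to either a fused implementation of the operator or a fallback implementation. It works out of the box in eager mode (PyTorch's default mode, where operations are executed as they are encountered) and integrates seamlessly with torch.compile(). As of the 2.6 release, it also supports grouped query attention natively.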
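The formula above can be checked directly against the operator. The following sketch compares F.scaled_dot_product_attention with a hand-written reference on small random CPU tensors; the shapes and the random additive bias B are illustrative assumptions.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, L, S, E = 2, 4, 5, 6, 8  # batch, heads, query length, key length, head dim
q = torch.randn(N, H, L, E)
k = torch.randn(N, H, S, E)
v = torch.randn(N, H, S, E)
bias = torch.randn(L, S)  # additive term B, broadcast over batch and heads

# Hand-written reference: softmax(Q K^T / sqrt(E) + B) V
reference = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(E) + bias, dim=-1) @ v

# A floating-point attn_mask is added to the attention scores before softmax.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

max_err = (reference - fused).abs().max().item()
print(f"max abs difference: {max_err:.2e}")
```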
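torch.compile() is a compiler introduced in version 2.0 that captures a graph of PyTorch code and performs various optimizations on it, such as fusing together sequences of ops. Nested tensors with the torch.jagged layout and scaled_dot_product_attention work seamlessly with compile. In the context of transformers, the value of combining compile with nested tensors and SDPA is that compile can remove the framework overhead seen in eager mode and fuse sequences of ops in transformers, such as projection and activation.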
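FlexAttention is a primitive that lets users modify attention scores prior to the softmax operation. It generalizes the additive B term in scaled_dot_product_attention above, allowing arbitrary computation. It requires compile to achieve good performance.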
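The sketch below illustrates what "modifying attention scores prior to softmax" means, using a dense reference written in plain PyTorch. It is not the FlexAttention implementation: reference_attention, causal, and relative_bias are hypothetical helpers, and only the callback signature (score, b, h, q_idx, kv_idx) is meant to mirror the score_mod signature that FlexAttention documents. In real code the callback would be passed to torch.nn.attention.flex_attention.flex_attention under torch.compile.

```python
import math

import torch
import torch.nn.functional as F


def reference_attention(q, k, v, score_mod):
    """Dense reference: apply score_mod to the scaled scores, then softmax."""
    B, H, L, E = q.shape
    S = k.shape[-2]
    scores = q @ k.transpose(-2, -1) / math.sqrt(E)
    # Index tensors shaped to broadcast against (B, H, L, S).
    b = torch.arange(B).view(B, 1, 1, 1)
    h = torch.arange(H).view(1, H, 1, 1)
    q_idx = torch.arange(L).view(1, 1, L, 1)
    kv_idx = torch.arange(S).view(1, 1, 1, S)
    return torch.softmax(score_mod(scores, b, h, q_idx, kv_idx), dim=-1) @ v


def causal(score, b, h, q_idx, kv_idx):
    # Keep scores where the query position can see the key position.
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


def relative_bias(score, b, h, q_idx, kv_idx):
    # Penalize attention by the distance between query and key positions.
    return score - (q_idx - kv_idx).abs()


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 6, 8) for _ in range(3))

out_causal = reference_attention(q, k, v, causal)
sdpa_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

idx = torch.arange(6)
additive_b = -(idx[:, None] - idx[None, :]).abs().float()
out_bias = reference_attention(q, k, v, relative_bias)
sdpa_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_b)
```

Both modifications here happen to be expressible with SDPA's is_causal flag or additive mask, which makes them easy to verify; the point of a score_mod callback is that it is not limited to such cases.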
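The above building blocks are "All You Need" (as of October 2024)#
The main premise of this section is that most transformer variants are GPT-style, consisting of layers such as Embedding, Positional Encoding, Attention Blocks, and Feed Forward networks. If we try to classify the differences in this space, we might land on something like:
Layer type (activation functions such as SwiGLU; normalization functions such as RMSNorm; positional encodings such as sinusoidal and rotary).
Layer ordering, such as where to apply norms and positional encoding.
Modifications to the attention score, such as ALiBi and relative positional bias.
In a pre-compiler environment, you might write a custom transformer and find that it is functionally correct but slow. To address this, you might need to develop a custom fused kernel for the specific series of operations. In a compiler environment, you can simply perform the first step, then compile and benefit from the improved performance.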
MultiheadAttention#
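Recall that MultiheadAttention takes a query, key, and value, and consists of an input projection, a scaled_dot_product_attention operator, and an output projection. The main takeaway we want to demonstrate here is the improvement obtained by replacing padded/masked inputs with nested tensors. The improvements are threefold:
User experience: Recall that nn.MultiheadAttention requires query, key, and value to be dense torch.Tensors. It also provides a key_padding_mask that is used to mask out padding tokens in the key that arise from different sequence lengths within a batch. Since nn.MHA has no query_padding_mask, users have to take care to mask or slice the outputs appropriately to account for query sequence lengths. NestedTensor cleanly removes the need for this sort of error-prone padding mask.
Memory: Instead of materializing a dense [B, S, D] tensor with a [B, S] padding mask (where B is the batch size, S is the maximum sequence length in the batch, and D is the embedding size), nested tensors let you cleanly represent a batch of varying sequence lengths. As a result, the input and intermediate activations use less memory.
Performance: Since padding is not materialized and unnecessary computation on padding is skipped, both performance and memory usage improve.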
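A back-of-the-envelope calculation shows how much padding can cost. The sketch below uses the batch statistics printed by the benchmark later in this tutorial (512 sequences, 11128 tokens in total, maximum length 148, embedding dimension 512); padded_vs_nested_elements is a hypothetical helper, and float32 storage (4 bytes per element) is assumed.

```python
def padded_vs_nested_elements(batch_size, max_len, total_len, embed_dim):
    """Return element counts for a dense padded [B, S, D] batch and a packed nested batch."""
    padded = batch_size * max_len * embed_dim
    nested = total_len * embed_dim
    return padded, nested


padded, nested = padded_vs_nested_elements(
    batch_size=512, max_len=148, total_len=11128, embed_dim=512
)
ratio = padded / nested

print(f"padded input: {padded * 4 / 1e6:.1f} MB")
print(f"nested input: {nested * 4 / 1e6:.1f} MB")
print(f"padded / nested: {ratio:.2f}x")
```

This counts only a single input tensor; intermediate activations scale the same way, which is consistent with the larger peak-memory gap reported in the benchmark.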
We demonstrate the above by building on the MultiheadAttention layer from the Nested Tensor tutorial and comparing it to the nn.MultiheadAttention layer.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
    """

    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        E_total: int,
        nheads: int,
        dropout: float = 0.0,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.nheads = nheads
        self.dropout = dropout
        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.bias = bias

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask=None,
        is_causal=False,
    ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
            attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        if self._qkv_same_embed_dim:
            if query is key and key is value:
                # Self-attention: one matmul produces query, key and value
                result = self.packed_proj(query)
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                # Same embedding dims but different inputs: slice the packed parameters
                q_weight, k_weight, v_weight = torch.chunk(
                    self.packed_proj.weight, 3, dim=0
                )
                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(
                        self.packed_proj.bias, 3, dim=0
                    )
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = (
                    F.linear(query, q_weight, q_bias),
                    F.linear(key, k_weight, k_bias),
                    F.linear(value, v_weight, v_bias),
                )
        else:
            query = self.q_proj(query)
            key = self.k_proj(key)
            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # NOTE: ``attn_mask`` is not forwarded to SDPA in this implementation;
        # only ``is_causal`` affects the attention pattern.
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout, is_causal=is_causal
        )
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output
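The head split in Step 2 and the merge after Step 3 are pure reshapes that undo each other. A small sketch on a dense tensor, with arbitrary illustrative sizes, makes the shape bookkeeping explicit:

```python
import torch

N, L, nheads, E_head = 2, 5, 4, 8
x = torch.randn(N, L, nheads * E_head)  # (N, L, E_total)

# Split: (N, L, E_total) -> (N, L, nheads, E_head) -> (N, nheads, L, E_head)
heads = x.unflatten(-1, [nheads, E_head]).transpose(1, 2)

# Merge: (N, nheads, L, E_head) -> (N, L, nheads, E_head) -> (N, L, E_total)
merged = heads.transpose(1, 2).flatten(-2)

print(tuple(heads.shape), tuple(merged.shape))
```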
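Utilities#
In this section, we include a utility that generates semi-realistic data, using a Zipf distribution for sentence lengths. It is used to generate the nested query, key, and value tensors. We also include a benchmark utility.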
import numpy as np


def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        # keep drawing words until a sentence terminator is sampled
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)


# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch has shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    if query_seq_len_1:
        query = torch.nested.nested_tensor(
            [torch.randn(1, E_q, dtype=dtype, device=device) for _ in sentence_lengths],
            layout=torch.jagged,
        )
    else:
        query = torch.nested.nested_tensor(
            [
                torch.randn(l.item(), E_q, dtype=dtype, device=device)
                for l in sentence_lengths
            ],
            layout=torch.jagged,
        )

    key = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_k, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    value = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_v, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    return query, key, value, sentence_lengths


import math
import timeit


def benchmark(func, *args, **kwargs):
    # synchronize before and after so that asynchronous CUDA work is included in the timing
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin), torch.cuda.max_memory_allocated()
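The benchmark utility above times a single call, so individual measurements can be noisy. As a hedged variation (not part of the tutorial's measurements), the sketch below repeats the call and reports the median; benchmark_median and its repeats parameter are hypothetical names, and the CUDA synchronization is skipped when no GPU is available so the example also runs on CPU.

```python
import statistics
import timeit

import torch


def benchmark_median(func, *args, repeats=10, **kwargs):
    """Run func several times and return its last output and the median wall time."""
    use_cuda = torch.cuda.is_available()
    times = []
    output = None
    for _ in range(repeats):
        if use_cuda:
            torch.cuda.synchronize()
        begin = timeit.default_timer()
        output = func(*args, **kwargs)
        if use_cuda:
            torch.cuda.synchronize()
        times.append(timeit.default_timer() - begin)
    return output, statistics.median(times)


out, seconds = benchmark_median(torch.matmul, torch.randn(64, 64), torch.randn(64, 64))
print(f"median time over 10 runs: {seconds:.6f} s")
```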
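We now demonstrate the performance improvements of using nested tensors in the MultiheadAttention layer with compile for self attention. We compare this against the traditional nn.MultiheadAttention with compile, using padding and masking.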
N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
d_model = E_q
nheads = 8
dropout = 0.0
bias = True
device = "cuda"
torch.manual_seed(6)
query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
S = sentence_lengths.max().item()
print(
    f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}"
)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

torch.manual_seed(6)
mha_layer = MultiHeadAttention(
    E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device="cuda"
)
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(
    E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device="cuda"
)

# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(
    vanilla_mha_layer.out_proj.weight.clone().detach()
)
mha_layer.packed_proj.weight = nn.Parameter(
    vanilla_mha_layer.in_proj_weight.clone().detach()
)
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
mha_layer.packed_proj.bias = nn.Parameter(
    vanilla_mha_layer.in_proj_bias.clone().detach()
)

new_mha_layer = torch.compile(mha_layer)
# warmup compile
nested_result_warmup = new_mha_layer(query, query, query, is_causal=True)

# benchmark
nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, query, query, is_causal=True
)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float("-inf"))
for i, s in enumerate(sentence_lengths):
    attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s)
attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S)

vanilla_mha_layer = torch.compile(vanilla_mha_layer)
# warmup compile
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_query,
    padded_query,
    attn_mask=attn_mask,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    is_causal=True,
)

# benchmark
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_query,
    padded_query,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    attn_mask=attn_mask,
    is_causal=True,
)

print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Total sequence length in nested query 11128, max sequence length 148
/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/autograd_cache.py:542: UserWarning: NestedTensor does not implement _stable_hash_for_caching. For PT2-compatible tensor subclasses, it is recommended to implement _stable_hash_for_caching(self) -> str for stable AOT autograd caching.
warn_once(
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:320: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
padded_time=0.01992, padded_peak_memory=4.37 GB
nested_time=0.00249, nested_peak_memory=0.79 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 7.99
Nested peak memory reduction 3.58 GB
For reference, here are some sample outputs on an A100:
padded_time=0.03454, padded_peak_memory=4.14 GB
nested_time=0.00612, nested_peak_memory=0.76 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 5.65
Nested peak memory reduction 3.39 GB
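The benchmark code derives src_key_padding_mask by comparing padded values against 0.0, which works for random data but would misfire if a real token happened to be exactly zero. As an alternative sketch, the mask can be built from the sequence lengths alone; key_padding_mask_from_lengths is a hypothetical helper, and the lengths below are illustrative. It returns a boolean mask in which True marks padding, the convention nn.MultiheadAttention uses for boolean key_padding_mask.

```python
import torch


def key_padding_mask_from_lengths(lengths, max_len=None):
    """Return a (B, S) boolean mask that is True at padded positions."""
    max_len = int(lengths.max()) if max_len is None else max_len
    positions = torch.arange(max_len, device=lengths.device)
    return positions[None, :] >= lengths[:, None]


lengths = torch.tensor([2, 4, 1])
mask = key_padding_mask_from_lengths(lengths)

# Equivalent additive float mask: -inf at padded positions, 0 elsewhere.
float_mask = torch.zeros(mask.shape).masked_fill(mask, float("-inf"))

print(mask)
```

With nested tensors none of this bookkeeping is needed, which is the user-experience point made earlier.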
We can also see the same for the backward pass.
for i, entry_length in enumerate(sentence_lengths):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

_, padded_bw_time, padded_bw_peak_mem = benchmark(
    lambda: padded_result.sum().backward()
)
_, nested_bw_time, nested_bw_peak_mem = benchmark(
    lambda: padded_nested_result.sum().backward()
)

print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB")
print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB")
print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}")
print(
    f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB"
)

print(
    "Difference in out_proj.weight.grad",
    (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in packed_proj.weight.grad",
    (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in out_proj.bias.grad",
    (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in packed_proj.bias.grad",
    (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad)
    .abs()
    .max()
    .item(),
)
padded_bw_time=1.45551, padded_bw_peak_mem=5.36 GB
nested_bw_time=0.06629, nested_bw_peak_mem=3.40 GB
Nested backward speedup: 21.96
Nested backward peak memory reduction 1.96 GB
Difference in out_proj.weight.grad 0.0003070831298828125
Difference in packed_proj.weight.grad 0.00179290771484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.001953125
Sample outputs on an A100:
padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
Nested backward speedup: 144.13
Nested backward peak memory reduction 1.86 GB
Difference in out_proj.weight.grad 0.000244140625
Difference in packed_proj.weight.grad 0.001556396484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.001953125
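GPT-style layer#
A basic GPT-style transformer layer consists of a causal self-attention layer followed by a feed-forward network (FFN) with skip connections. Implementing this with the MultiheadAttention layer above is fairly straightforward and gives results equivalent to an nn.TransformerEncoderLayer with is_causal=True.
Examples of implementing the remaining nn layers exist elsewhere, but this tutorial omits them for brevity.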
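As a rough illustration of the structure just described, here is a minimal sketch of a GPT-style layer: causal self-attention followed by an FFN, each wrapped in a skip connection. TinyGPTLayer is a hypothetical class written for this example with dense inputs; the post-norm ordering, ReLU activation, and sizes are assumptions, and it is not claimed to match nn.TransformerEncoderLayer numerically.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGPTLayer(nn.Module):
    """Hypothetical post-norm layer: causal self-attention + FFN with skip connections."""

    def __init__(self, d_model, nheads, dim_feedforward):
        super().__init__()
        assert d_model % nheads == 0, "d_model must be divisible by nheads"
        self.nheads = nheads
        self.head_dim = d_model // nheads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)  # packed input projection
        self.out_proj = nn.Linear(d_model, d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # (N, L, 3 * d_model) -> three tensors of shape (N, nheads, L, head_dim)
        q, k, v = (
            t.unflatten(-1, [self.nheads, self.head_dim]).transpose(1, 2)
            for t in torch.chunk(self.qkv_proj(x), 3, dim=-1)
        )
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = self.out_proj(attn.transpose(1, 2).flatten(-2))
        x = self.norm1(x + attn)            # skip connection around attention
        return self.norm2(x + self.ffn(x))  # skip connection around the FFN


torch.manual_seed(0)
layer = TinyGPTLayer(d_model=32, nheads=4, dim_feedforward=64).eval()
x = torch.randn(2, 6, 32)
x_changed = x.clone()
x_changed[:, 4:] = torch.randn(2, 2, 32)  # perturb only the last two positions

with torch.no_grad():
    y, y_changed = layer(x), layer(x_changed)

# Causality check: outputs at positions 0..3 must not depend on positions 4..5.
print(torch.allclose(y[:, :4], y_changed[:, :4], atol=1e-5))
```

The check at the end perturbs only later tokens and confirms that earlier outputs are unaffected, which is the defining property of causal attention.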
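Going one step further#
So far, we have demonstrated how to implement a performant MultiheadAttention layer that follows the traditional nn.MultiheadAttention. Going back to our classification of modifications to the transformer architecture, recall that we classified them into layer type, layer ordering, and modifications to the attention score. We trust that changing layer type and layer ordering (such as swapping LayerNorm for RMSNorm) is fairly straightforward.
In this section, we discuss various functionalities using the aforementioned building blocks, including the following:
Cross Attention
Fully masked rows no longer cause NaNs
Packed Projection
Cross Attention#
Cross attention is a form of attention where the query and key/value tensors come from different sequences.
One example of this is nn.TransformerDecoderLayer, where the query comes from the decoder and the key/value come from the encoder.
The MultiheadAttention layer above generalizes nicely to this case, with nested tensors for both the query and the key/value.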
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)

print(
    f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}"
)
print(
    f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}"
)
out = new_mha_layer(query, key, value, is_causal=False)
Total sequence length in nested query 10506, max sequence length 102
Total sequence length in nested key/value 10825, max sequence length 144
As above, we can compare this against the vanilla compiled nn.MultiheadAttention.
torch.manual_seed(6)
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]

# warmup compile
warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)

nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, key, value, is_causal=False
)
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)
padded_nested_result = nested_result.to_padded_tensor(0.0)
for i, entry_length in enumerate(q_len):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Max difference between vanilla and nested result 1.5497207641601562e-06
Nested speedup: 6.15
Nested peak memory reduction 1.31 GB
Sample outputs on an A100:
Max difference between vanilla and nested result 0.0
Nested speedup: 4.01
Nested peak memory reduction 1.40 GB
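Fully masked rows no longer cause NaNs#
There has been a long-standing issue with nn.MultiheadAttention and scaled_dot_product_attention where, if a row was fully masked out, the output of the attention layer would be NaN. See the related issue. This is because the softmax over an empty set is undefined.
Thanks to a PR, this is no longer the case. Instead, the output corresponding to fully masked rows in scaled_dot_product_attention is 0. The same holds for nn.MHA in cases where it does not use the "fast path".
Using a custom MHA layer with NJTs is strongly recommended over the existing "fast path" in nn.MultiheadAttention, because NJT's ability to model raggedness makes it possible to properly express empty sequences.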
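The origin of the NaN is easy to reproduce with a plain softmax. The sketch below shows a fully masked row producing NaN and then applies the "output 0 for fully masked rows" convention by hand. It only illustrates the convention described above; it is not how scaled_dot_product_attention implements it internally, and the score values are made up.

```python
import torch

neg_inf = float("-inf")
scores = torch.tensor(
    [
        [0.5, 1.0, -0.3],             # regular row
        [neg_inf, neg_inf, neg_inf],  # fully masked row
    ]
)

# exp(-inf - max) with max = -inf evaluates to NaN, so the whole row becomes NaN.
naive = torch.softmax(scores, dim=-1)

# Convention: rows with nothing to attend to produce zeros instead.
fully_masked = (scores == neg_inf).all(dim=-1, keepdim=True)
safe = torch.where(fully_masked, torch.zeros_like(naive), naive)

print(naive)
print(safe)
```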
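Packed Projection#
Packed projection is a technique that exploits the fact that, when the inputs to the projections (matrix multiplications) are the same (self-attention), we can pack the projection weights and biases into single tensors. It is especially useful when the individual projections are memory bound rather than compute bound. We demonstrate two examples here:
Input projection for MultiheadAttention
SwiGLU activation in the feed-forward network of a Transformer layer
Input projection for MultiheadAttention#
In self-attention, query, key, and value are the same tensor. Each of these tensors is projected with a Linear(E_q, E_total) layer. Instead, we can pack the three projections into a single layer, which is what the MultiheadAttention layer above does.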
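Before looking at performance, it may help to confirm that packing does not change the result. This sketch concatenates the weights and biases of three separate Linear layers into one packed layer and checks that chunking its output reproduces the separate projections. The small sizes are arbitrary, and the example runs on CPU.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
E_q, E_total = 16, 16
q_proj, k_proj, v_proj = (nn.Linear(E_q, E_total) for _ in range(3))

# One Linear layer whose rows are the stacked rows of the three projections.
packed_proj = nn.Linear(E_q, 3 * E_total)
with torch.no_grad():
    packed_proj.weight.copy_(
        torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
    )
    packed_proj.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

x = torch.randn(4, 10, E_q)
with torch.no_grad():
    # A single matmul, then split the last dimension into three equal chunks.
    q, k, v = torch.chunk(packed_proj(x), 3, dim=-1)
    q_ref, k_ref, v_ref = q_proj(x), k_proj(x), v_proj(x)

print((q - q_ref).abs().max().item())
```

The packed version reads the input once and launches one matmul instead of three, which is why it helps most when the projections are memory bound.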
Let us compare the performance of the packed projection against the usual method.
class InputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

    def forward(self, x):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)


class PackedInputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)

    def forward(self, query):
        return torch.chunk(self.packed_proj(query), 3, dim=-1)


B, D, dtype = 256, 8192, torch.bfloat16

torch.set_float32_matmul_precision("high")
in_proj = torch.compile(InputProjection(D, D, device="cuda", dtype=torch.bfloat16))
packed_in_proj = torch.compile(
    PackedInputProjection(D, D, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sequence_lengths = gen_batch(B, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
in_proj(q)
packed_in_proj(q)

# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
# On my A100 prints 1.05x speedup
print(
    f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x"
)
InputProjection: 0.035766 s, PackedInputProjection: 0.035048 s, speedup: 1.02x
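SwiGLU feed-forward network of a Transformer layer#
Swish-Gated Linear Unit (SwiGLU) is a non-linear activation function that is increasingly popular in the feed-forward network of the transformer layer (for example, in Llama). A feed-forward network with SwiGLU activation is defined as: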
class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
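The hidden size computed in the constructor is worth tracing once by hand. The sketch below extracts that arithmetic into a standalone function (swiglu_hidden_dim is a hypothetical name) and evaluates it for the D=128 configuration used in the benchmark below, and for dim=4096 as a second illustrative input.

```python
def swiglu_hidden_dim(hidden_dim, multiple_of, ffn_dim_multiplier=None):
    """Mirror the hidden-size arithmetic from the SwiGLUFFN constructor."""
    # Scale by 2/3 so three weight matrices cost roughly as much as two unscaled ones.
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the nearest multiple of `multiple_of`.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


small = swiglu_hidden_dim(4 * 128, 256)   # 512 -> 341 -> rounded up to 512
large = swiglu_hidden_dim(4 * 4096, 256)  # 16384 -> 10922 -> rounded up to 11008
print(small, large)
```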
An alternative implementation that uses packed projection is:
class PackedSwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

    def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)
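We can compare the performance of the two implementations as follows. Depending on your hardware, you might see different results. On an A100, I see a 1.12x speedup for D=128.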
D = 128

swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16))
packed_swigluffn = torch.compile(
    PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
swigluffn(q)
packed_swigluffn(q)

# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
# On my A100 prints 1.08x speedup
print(
    f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x"
)
SwiGLUFFN: 0.0008703010003046074 s, PackedSwiGLUFFN: 0.0007882889999564213 s, speedup: 1.10x
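Extended examples#
We intend to update this tutorial with more examples of how to use the various performant building blocks, such as KV-Caching and Grouped Query Attention. Beyond that, there are several good examples of using these building blocks to implement various transformer architectures.
Conclusion#
In this tutorial, we introduced the low-level building blocks PyTorch provides for writing transformer layers and demonstrated examples of how to compose them. We hope this tutorial has shown how easily PyTorch users can implement flexible and performant transformer layers.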
Total running time of the script: (0 minutes 16.645 seconds)