torch.nn.attention.flex_attention#
创建于: 2024年7月16日 | 最后更新于: 2025年9月8日
- torch.nn.attention.flex_attention.flex_attention(query, key, value, score_mod=None, block_mask=None, scale=None, enable_gqa=False, return_lse=False, kernel_options=None, *, return_aux=None)[source]#
该函数实现了具有任意注意力分数修改函数的缩放点积注意力。
该函数在查询、键和值张量之间计算缩放点积注意力,并使用用户定义的注意力分数修改函数。注意力分数修改函数将在查询和键张量之间的注意力分数计算完成后应用。注意力分数的计算方式如下:
score_mod
函数应具有以下签名:def score_mod( score: Tensor, batch: Tensor, head: Tensor, q_idx: Tensor, k_idx: Tensor ) -> Tensor:
- 其中
score
:一个标量张量,表示注意力分数,其数据类型和设备与查询、键和值张量相同。batch
、head
、q_idx
、k_idx
:标量张量,分别指示批次索引、查询头索引、查询索引和键/值索引。这些应具有torch.int
数据类型,并位于与分数张量相同的设备上。
- 参数
query (Tensor) – 查询张量;形状为 。对于 FP8 数据类型,应采用行主内存布局以获得最佳性能。
key (Tensor) – 键张量;形状为 。对于 FP8 数据类型,应采用行主内存布局以获得最佳性能。
value (Tensor) – 值张量;形状为 。对于 FP8 数据类型,应采用列主内存布局以获得最佳性能。
score_mod (Optional[Callable]) – 用于修改注意力分数的函数。默认情况下,不应用 score_mod。
block_mask (Optional[BlockMask]) – BlockMask 对象,用于控制注意力的块稀疏性模式。
scale (Optional[float]) – 在 softmax 之前应用的缩放因子。如果为 None,则默认值为 。
enable_gqa (bool) – 如果设置为 True,则启用分组查询注意力(GQA)并向查询头广播键/值头。
return_lse (bool) – 是否返回注意力分数的对数和(logsumexp)。默认为 False。已弃用:请改用
return_aux=AuxRequest(lse=True)
。kernel_options (Optional[FlexKernelOptions]) – 用于控制底层 Triton 内核行为的选项。有关可用选项和用法示例,请参阅
FlexKernelOptions
。return_aux (Optional[AuxRequest]) – 指定要计算和返回的辅助输出。如果为 None,则只返回注意力输出。使用
AuxRequest(lse=True, max_scores=True)
来请求两个辅助输出。
- 返回
注意力输出;形状为 。
- 当
return_aux
不为 None 时 aux (AuxOutput): 包含已请求字段的辅助输出。
- 当
return_aux
为 None 时(已弃用路径) lse (Tensor): 注意力分数的对数和;形状为 。仅当
return_lse=True
时返回。
- 当
- 返回类型
output (Tensor)
- 形状说明
警告
torch.nn.attention.flex_attention 是 PyTorch 中的一个原型功能。请期待 PyTorch 未来版本中更稳定的实现。有关功能分类的更多信息,请访问:https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype
- class torch.nn.attention.flex_attention.AuxOutput(lse=None, max_scores=None)[source]#
flex_attention 操作的辅助输出。
如果未请求,字段将为 None;如果已请求,则包含张量。
- class torch.nn.attention.flex_attention.AuxRequest(lse=False, max_scores=False)[source]#
请求从 flex_attention 计算哪些辅助输出。
每个字段都是一个布尔值,指示是否应计算该辅助输出。
BlockMask 工具#
- torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device='cuda', BLOCK_SIZE=128, _compile=False)[source]#
此函数从 mask_mod 函数创建块掩码元组。
- 参数
mask_mod (Callable) – mask_mod 函数。这是一个可调用对象,用于定义注意力机制的掩码模式。它接受四个参数:b(批次大小)、h(头数)、q_idx(查询索引)和 kv_idx(键/值索引)。它应返回一个布尔张量,指示哪些注意力连接是允许的(True)或被掩码掉的(False)。
B (int) – 批次大小。
H (int) – 查询头数。
Q_LEN (int) – 查询的序列长度。
KV_LEN (int) – 键/值的序列长度。
device (str) – 用于运行掩码创建的设备。
BLOCK_SIZE (int 或 tuple[int, int]) – 块掩码的块大小。如果提供单个整数,则同时用于查询和键/值。
- 返回
一个 BlockMask 对象,其中包含块掩码信息。
- 返回类型
- 示例用法
def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) output = flex_attention(query, key, value, block_mask=block_mask)
- torch.nn.attention.flex_attention.create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device='cuda')[source]#
此函数从 mod_fn 函数创建掩码张量。
- torch.nn.attention.flex_attention.create_nested_block_mask(mask_mod, B, H, q_nt, kv_nt=None, BLOCK_SIZE=128, _compile=False)[source]#
此函数从 mask_mod 函数创建与嵌套张量兼容的块掩码元组。返回的 BlockMask 将位于输入嵌套张量指定的设备上。
- 参数
mask_mod (Callable) – mask_mod 函数。这是一个可调用对象,用于定义注意力机制的掩码模式。它接受四个参数:b(批次大小)、h(头数)、q_idx(查询索引)和 kv_idx(键/值索引)。它应返回一个布尔张量,指示哪些注意力连接是允许的(True)或被掩码掉的(False)。
B (int) – 批次大小。
H (int) – 查询头数。
q_nt (torch.Tensor) – 锯齿状布局嵌套张量(NJT),用于定义查询的序列长度结构。块掩码将构造为作用于 NJT 中序列长度
S
的“堆叠序列”的长度sum(S)
。kv_nt (torch.Tensor) – 锯齿状布局嵌套张量(NJT),用于定义键/值的序列长度结构,允许交叉注意力。块掩码将构造为作用于 NJT 中序列长度
S
的“堆叠序列”的长度sum(S)
。如果此参数为 None,则q_nt
也将用于定义键/值的结构。默认为 NoneBLOCK_SIZE (int 或 tuple[int, int]) – 块掩码的块大小。如果提供单个整数,则同时用于查询和键/值。
- 返回
一个 BlockMask 对象,其中包含块掩码信息。
- 返回类型
- 示例用法
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_nested_block_mask( causal_mask, 1, 1, query, _compile=True ) output = flex_attention(query, key, value, block_mask=block_mask)
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx # cross attention case: pass both query and key/value NJTs block_mask = create_nested_block_mask( causal_mask, 1, 1, query, key, _compile=True ) output = flex_attention(query, key, value, block_mask=block_mask)
FlexKernelOptions#
- class torch.nn.attention.flex_attention.FlexKernelOptions[source]#
FlexAttention 内核的行为控制选项。
这些选项将传递给底层 Triton 内核,以控制性能和数值行为。大多数用户不需要指定这些选项,因为默认的自动调整提供了良好的性能。
选项可以加上
fwd_
或bwd_
前缀,以便分别仅应用于前向或后向传递。例如:fwd_BLOCK_M
和bwd_BLOCK_M1
。注意
目前我们不为这些选项提供任何向后兼容性保证。尽管如此,自引入以来,其中大部分选项都相当稳定。但我们暂时不认为这是公共 API 的一部分。我们认为文档比隐藏的秘密标志更好,但我们将来可能会更改这些选项。
- 示例用法
# Using dictionary (backward compatible) kernel_opts = {"BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": True} output = flex_attention(q, k, v, kernel_options=kernel_opts) # Using TypedDict (recommended for type safety) from torch.nn.attention.flex_attention import FlexKernelOptions kernel_opts: FlexKernelOptions = { "BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": True, } output = flex_attention(q, k, v, kernel_options=kernel_opts) # Forward/backward specific options kernel_opts: FlexKernelOptions = { "fwd_BLOCK_M": 64, "bwd_BLOCK_M1": 32, "PRESCALE_QK": False, } output = flex_attention(q, k, v, kernel_options=kernel_opts)
- BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]#
如果为 True,则保证掩码中的所有块都是连续的。允许优化块遍历。例如,因果掩码会满足此条件,但前缀 LM + 滑动窗口则不会。默认为 False。
- FORCE_USE_FLEX_ATTENTION: NotRequired[bool]#
如果为 True,则强制使用 flex attention 内核,而不是可能为短序列使用更优化的 flex-decoding 内核。这对于调试来说是一个有用的选项。默认为 False。
- ROWS_GUARANTEED_SAFE: NotRequired[bool]#
如果为 True,则保证每行至少有一个值未被掩码掉。允许跳过安全检查以获得更好的性能。只有当您确定掩码保证此属性时才设置此项。例如,因果注意力被保证是安全的,因为每个查询至少有 1 个键-值可以关注。默认为 False。
- USE_TMA: NotRequired[bool]#
是否在支持的硬件上使用 Tensor Memory Accelerator (TMA)。这处于实验阶段,可能无法在所有硬件上运行,目前仅限于 NVIDIA GPU Hopper+。默认为 False。
BlockMask#
- class torch.nn.attention.flex_attention.BlockMask(seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)[source]#
BlockMask 是我们用于表示块稀疏注意力掩码的格式。它在某种程度上介于 BCSR 和非稀疏格式之间。
基础知识
块稀疏掩码意味着,与其表示掩码中单个元素的稀疏性,不如将 KV_BLOCK_SIZE x Q_BLOCK_SIZE 块视为稀疏,仅当该块内的每个元素都稀疏时。这与硬件的期望非常吻合,硬件通常期望进行连续的加载和计算。
此格式主要针对 1. 简单性;2. 内核效率进行了优化。值得注意的是,它 *不* 针对大小进行优化,因为此掩码的大小总是除以 KV_BLOCK_SIZE * Q_BLOCK_SIZE。如果大小是问题,可以通过增加块大小来减小张量的大小。
我们格式的关键点是:
num_blocks_in_row: Tensor[ROWS]: 描述每行中存在的块数。
col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]:
col_indices[i]
是第 i 行的块位置序列。此行中col_indices[i][num_blocks_in_row[i]]
之后的值未定义。例如,要从该格式中恢复原始张量:
dense_mask = torch.zeros(ROWS, COLS) for row in range(ROWS): for block_idx in range(num_blocks_in_row[row]): dense_mask[row, col_indices[row, block_idx]] = 1
值得注意的是,此格式使得沿着掩码的*行*进行归约操作更容易。
详细信息
我们格式的基本要求是仅 kv_num_blocks 和 kv_indices。但是,我们在此对象上有多达 8 个张量。这代表 4 对:
1. (kv_num_blocks, kv_indices): 用于注意力的前向传递,因为我们沿着 KV 维度进行归约。
2. [可选] (full_kv_num_blocks, full_kv_indices): 这是可选的,纯粹是为了优化。事实证明,对每个块应用掩码成本很高!如果我们特别知道哪些块是“完整的”并且不需要应用掩码,那么我们可以跳过将 mask_mod 应用于这些块。这要求用户将 mask_mod 分开,而不是从 score_mod 中分离。对于因果掩码,这可以带来约 15% 的速度提升。
3. [生成] (q_num_blocks, q_indices): 后向传递需要,因为计算 dKV 需要沿着 Q 维度沿掩码进行迭代。这些是根据 1 自动生成的。
4. [生成] (full_q_num_blocks, full_q_indices): 与上面相同,但用于后向传递。这些是根据 2 自动生成的。
- as_tuple(flatten=True)[source]#
返回 BlockMask 属性的元组。
- 参数
flatten (bool) – 如果为 True,则将 (KV_BLOCK_SIZE, Q_BLOCK_SIZE) 的元组展平。
- classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None, compute_q_blocks=True)[source]#
从键值块信息创建 BlockMask 实例。
- 参数
kv_num_blocks (Tensor) – 每个 Q_BLOCK_SIZE 行块的 kv_blocks 数量。
kv_indices (Tensor) – 每个 Q_BLOCK_SIZE 行块的键值块索引。
full_kv_num_blocks (Optional[Tensor]) – 每个 Q_BLOCK_SIZE 行块中的完整 kv_blocks 数量。
full_kv_indices (Optional[Tensor]) – 每个 Q_BLOCK_SIZE 行块中的完整键值块索引。
BLOCK_SIZE (Union[int, tuple[int, int]]) – KV_BLOCK_SIZE x Q_BLOCK_SIZE 块的大小。
mask_mod (Optional[Callable]) – 用于修改掩码的函数。
- 返回
通过 _transposed_ordered 生成完整 Q 信息的实例。
- 返回类型
- 引发
RuntimeError – 如果 kv_indices 的维度小于 2。
AssertionError – 如果只提供 full_kv_* 参数中的一个。
- property shape#
- to(device)[source]#
将 BlockMask 移动到指定的设备。
- 参数
device (torch.device 或 str) – 要将 BlockMask 移动到的目标设备。可以是 torch.device 对象或字符串(例如,‘cpu’、‘cuda:0’)。
- 返回
一个将所有张量组件移动到指定设备的新 BlockMask 实例。
- 返回类型
注意
此方法不会就地修改原始 BlockMask。相反,它返回一个新的 BlockMask 实例,其中各个张量属性可能会或可能不会移动到指定设备,具体取决于它们当前的设备放置。