torch.nn.functional.scaled_dot_product_attention#
- torch.nn.functional.scaled_dot_product_attention()#
- scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> Tensor
计算 query、key 和 value 张量的缩放点积注意力,如果提供了 attention mask,则使用它,如果指定了大于 0.0 的概率,则应用 dropout。可选的 scale 参数只能作为关键字参数指定。
# Efficient implementation equivalent to the following: def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias = attn_mask + attn_bias if enable_gqa: key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value
警告
此函数处于 beta 阶段,可能会发生更改。
警告
此函数始终根据指定的
dropout_p
参数应用 dropout。要在评估期间禁用 dropout,请确保在调用该函数的模块不在训练模式下时将0.0
传递给它。例如
class MyModel(nn.Module): def __init__(self, p=0.5): super().__init__() self.p = p def forward(self, ...): return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0))
注意
目前支持三种缩放点积注意力实现:
C++ 实现的 PyTorch 版本,匹配上述公式
在使用 CUDA 后端时,该函数可能会调用优化内核以提高性能。对于所有其他后端,将使用 PyTorch 实现。
所有实现默认都已启用。缩放点积注意力会尝试根据输入自动选择最优的实现。为了对使用的实现进行更精细化的控制,提供了以下函数来启用和禁用实现。上下文管理器是首选机制。
torch.nn.attention.sdpa_kernel()
: 一个上下文管理器,用于启用或禁用任何实现。torch.backends.cuda.enable_flash_sdp()
: 全局启用或禁用 FlashAttention。torch.backends.cuda.enable_mem_efficient_sdp()
: 全局启用或禁用内存高效注意力。torch.backends.cuda.enable_math_sdp()
: 全局启用或禁用 PyTorch C++ 实现。
每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现,请使用
torch.nn.attention.sdpa_kernel()
禁用 PyTorch C++ 实现。如果融合实现不可用,将引发警告,说明无法运行融合实现的原因。由于浮点运算融合的性质,此函数的输出可能因选择的后端内核而异。C++ 实现支持 torch.float64,在需要更高精度时可以使用。对于 math 后端,如果输入为 torch.half 或 torch.bfloat16,则所有中间值都保留为 torch.float。
更多信息请参阅 数值精度。
分组查询注意力 (GQA) 是一项实验性功能。目前它仅适用于 CUDA 张量上的 Flash_attention 和 math 内核,不支持 Nested tensor。GQA 的约束条件:
number_of_heads_query % number_of_heads_key_value == 0 且,
number_of_heads_key == number_of_heads_value
注意
在某些情况下,当在 CUDA 设备上使用张量并利用 CuDNN 时,此算子可能会选择一个非确定性算法来提高性能。如果这不可取,你可以尝试将操作设置为确定性的(可能以性能为代价),方法是设置
torch.backends.cudnn.deterministic = True
。有关更多信息,请参阅 可复现性。- 参数
query (Tensor) – 查询张量;形状为 。
key (Tensor) – 键张量;形状为 。
value (Tensor) – 值张量;形状为 。
attn_mask (optional Tensor) – 注意力掩码;形状必须可广播到注意力权重形状,即 。支持两种类型的掩码。布尔掩码,其中 True 值表示该元素应该参与注意力。与 query、key、value 类型相同的浮点掩码,该掩码会加到注意力分数上。
dropout_p (float) – Dropout 概率;如果大于 0.0,则应用 dropout。
is_causal (bool) – 如果设置为 true,则在掩码为方阵时,注意力掩码为下三角矩阵。当掩码为非方阵时,注意力掩码的形式是由于对齐而产生的左上因果偏差(参见
torch.nn.attention.bias.CausalBias
)。如果同时设置了 attn_mask 和 is_causal,则会引发错误。scale (optional python:float, keyword-only) – Softmax 之前应用的缩放因子。如果为 None,则默认值为 。
enable_gqa (bool) – 如果设置为 True,则启用分组查询注意力 (GQA),默认设置为 False。
- 返回
注意力输出;形状为 。
- 返回类型
output (Tensor)
- 形状说明
示例
>>> # Optionally use the context manager to ensure one of the fused kernels is run >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): >>> F.scaled_dot_product_attention(query,key,value)
>>> # Sample for GQA for llama3 >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> with sdpa_kernel(backends=[SDPBackend.MATH]): >>> F.scaled_dot_product_attention(query,key,value,enable_gqa=True)