MultiheadAttention#
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source]#
允许模型联合关注来自不同表示子空间的信息。
此 MultiheadAttention 层实现了 Attention Is All You Need 论文中描述的原始架构。此层的目的是作为基础理解的参考实现,因此它相对于较新的架构仅包含有限的功能。鉴于 Transformer 类架构的快速创新,我们建议探索此 教程,以从核心的构建块中构建高效的层,或使用 PyTorch 生态系统 中的更高级库。
多头注意力定义为
其中 .
nn.MultiheadAttention
在可能的情况下将使用scaled_dot_product_attention()
的优化实现。除了支持新的
scaled_dot_product_attention()
函数外,为了加速推理,MHA 将使用 fastpath 推理,并支持 Nested Tensors,当且仅当:计算自注意力(即
query
、key
和value
是同一个张量)。输入是批处理的(3D)并且
batch_first==True
autograd 被禁用(使用
torch.inference_mode
或torch.no_grad
)或没有张量参数requires_grad
training
被禁用(使用.eval()
)add_bias_kv
为False
add_zero_attn
为False
kdim
和vdim
等于embed_dim
如果传递了 NestedTensor,则既不传递
key_padding_mask
也不传递attn_mask
autocast 被禁用
如果使用了优化的推理 fastpath 实现,则可以为
query
/key
/value
传递 NestedTensor 来更有效地表示 padding,而不是使用 padding mask。在这种情况下,将返回一个 NestedTensor,并且可以预期速度会提高,具体取决于输入中 padding 的比例。- 参数
embed_dim – 模型的总维度。
num_heads – 并行注意力头的数量。请注意,
embed_dim
将被分割到num_heads
中(即每个头的维度为embed_dim // num_heads
)。dropout –
attn_output_weights
上的 dropout 概率。默认值:0.0
(无 dropout)。bias – 如果指定,则在输入/输出投影层添加偏置。默认值:
True
。add_bias_kv – 如果指定,则在
dim=0
处向 key 和 value 序列添加偏置。默认值:False
。add_zero_attn – 如果指定,则在
dim=1
处向 key 和 value 序列添加新的零批次。默认值:False
。kdim – key 的总特征数。默认值:
None
(使用kdim=embed_dim
)。vdim – value 的总特征数。默认值:
None
(使用vdim=embed_dim
)。batch_first – 如果为
True
,则输入和输出张量表示为 (batch, seq, feature)。默认值:False
(seq, batch, feature)。
示例
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source]#
使用 query、key 和 value 嵌入计算注意力输出。
支持用于 padding、mask 和注意力权重的可选参数。
- 参数
query (Tensor) – 对于未批处理的输入,query 嵌入的形状为 ,当
batch_first=False
时为 ,当batch_first=True
时为 ,其中 是目标序列长度, 是批次大小, 是 query 嵌入维度embed_dim
。Queries 与 key-value 对进行比较以生成输出。更多详细信息请参阅“Attention Is All You Need”。key (Tensor) – 对于未批处理的输入,key 嵌入的形状为 ,当
batch_first=False
时为 ,当batch_first=True
时为 ,其中 是源序列长度, 是批次大小, 是 key 嵌入维度kdim
。更多详细信息请参阅“Attention Is All You Need”。value (Tensor) – 对于未批处理的输入,value 嵌入的形状为 ,当
batch_first=False
时为 ,当batch_first=True
时为 ,其中 是源序列长度, 是批次大小, 是 value 嵌入维度vdim
。更多详细信息请参阅“Attention Is All You Need”。key_padding_mask (Optional[Tensor]) – 如果指定,形状为 的 mask,用于指示
key
中哪些元素应被忽略(即被视为“padding”)。对于未批处理的 query,形状应为 。支持二进制和浮点 masks。对于二进制 mask,True
值表示相应的key
值将被忽略。对于浮点 mask,它将直接添加到相应的key
值中。need_weights (bool) – 如果指定,除了
attn_outputs
外,还将返回attn_output_weights
。将need_weights=False
以使用优化的scaled_dot_product_attention
并获得 MHA 的最佳性能。默认值:True
。attn_mask (Optional[Tensor]) – 如果指定,一个 2D 或 3D mask,用于阻止对某些位置的注意力。形状必须为 或 ,其中 是批次大小, 是目标序列长度, 是源序列长度。2D mask 将被广播到批次,而 3D mask 则允许批次中的每个条目都有不同的 mask。支持二进制和浮点 masks。对于二进制 mask,
True
值表示不允许对相应位置进行注意力。对于浮点 mask,mask 值将添加到注意力权重中。如果同时提供了 attn_mask 和 key_padding_mask,它们的类型应匹配。average_attn_weights (bool) – 如果为 true,表示返回的
attn_weights
应在 heads 之间平均。否则,attn_weights
将为每个 head 单独提供。请注意,此标志仅在need_weights=True
时有效。默认值:True
(即在 heads 之间平均权重)is_causal (bool) – 如果指定,则将因果 mask 作为注意力 mask 应用。默认值:
False
。警告:is_causal
提供了attn_mask
是因果 mask 的提示。提供不正确的提示可能导致执行不正确,包括向前和向后兼容性。
- 返回类型
- 输出
attn_output - 注意力输出,当输入未批处理时形状为 ,当
batch_first=False
时形状为 ,或当batch_first=True
时形状为 ,其中 是目标序列长度, 是批次大小, 是嵌入维度embed_dim
。attn_output_weights - 仅当
need_weights=True
时返回。如果average_attn_weights=True
,返回跨头的平均注意力权重,当输入未批处理时形状为 ,或当输入已批处理时形状为 ,其中 是批次大小, 是目标序列长度, 是源序列长度。如果average_attn_weights=False
,返回每个头的注意力权重,当输入未批处理时形状为 ,或当输入已批处理时形状为 。
注意
batch_first 参数对于未批处理的输入将被忽略。
- merge_masks(attn_mask, key_padding_mask, query)[来源]#
确定掩码类型并合并掩码(如果需要)。
如果仅提供一个掩码,则返回该掩码和对应的掩码类型。如果同时提供两个掩码,它们将被扩展为形状
(batch_size, num_heads, seq_len, seq_len)
,然后通过逻辑or
组合,并返回掩码类型 2::param attn_mask: 注意力掩码,形状为(seq_len, seq_len)
,掩码类型 0 :param key_padding_mask: 填充掩码,形状为(batch_size, seq_len)
,掩码类型 1 :param query: 查询嵌入,形状为(batch_size, seq_len, embed_dim)
- 返回
merged mask mask_type: 合并后的掩码类型(0、1 或 2)
- 返回类型
merged_mask