评价此页

MultiheadAttention#

class torch.nn.modules.activation.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源代码]#

允许模型联合关注来自不同表示子空间的信息。

此 MultiheadAttention 层实现了 Attention Is All You Need 论文中描述的原始架构。此层的目的是作为基础理解的参考实现,因此它仅包含相对于较新架构的有限功能。鉴于 Transformer 类架构的快速创新步伐,我们建议探索此 教程,从核心构建块构建高效层,或使用 PyTorch 生态系统 中的更高级库。

多头注意力定义为

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

其中 headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

nn.MultiheadAttention 在可能的情况下将使用 scaled_dot_product_attention() 的优化实现。

除了支持新的 scaled_dot_product_attention() 函数外,为了加速推理,MHA 将使用快速路径推理,并支持 Nested Tensors,前提是:

  • 正在计算自注意力(即 querykeyvalue 是同一个张量)。

  • 输入是批处理的(3D),并且 batch_first==True

  • 自动梯度被禁用(使用 torch.inference_modetorch.no_grad)或者没有张量参数 requires_grad

  • 训练被禁用(使用 .eval()

  • add_bias_kvFalse

  • add_zero_attnFalse

  • kdimvdim 等于 embed_dim

  • 如果传递了 NestedTensor,则不传递 key_padding_maskattn_mask

  • autocast 被禁用。

如果正在使用优化的推理快速路径实现,则可以将 NestedTensor 传递给 query/key/value,以比使用 padding mask 更有效地表示 padding。在这种情况下,将返回一个 NestedTensor,并且可以预期相对于输入中 padding 的比例有额外的加速。

参数
  • embed_dim – 模型的总维度。

  • num_heads – 并行注意力头的数量。请注意,embed_dim 将被分割给 num_heads(即每个头的维度将是 embed_dim // num_heads)。

  • dropoutattn_output_weights 上的 Dropout 概率。默认为 0.0(无 dropout)。

  • bias – 如果指定,则为输入/输出投影层添加偏置。默认为 True

  • add_bias_kv – 如果指定,则在 dim=0 处为 key 和 value 序列添加偏置。默认为 False

  • add_zero_attn – 如果指定,则在 dim=1 处为 key 和 value 序列添加新的零批次。默认为 False

  • kdim – key 的总特征数。默认为 None(使用 kdim=embed_dim)。

  • vdim – value 的总特征数。默认为 None(使用 vdim=embed_dim)。

  • batch_first – 如果为 True,则输入和输出张量为 (batch, seq, feature)。默认为 False(seq, batch, feature)。

示例

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源代码]#

使用 query、key 和 value 嵌入计算注意力输出。

支持 padding、mask 和注意力权重的可选参数。

参数
  • query (Tensor) – shape 为 (L,Eq)(L, E_q) 的未批处理输入,当 batch_first=False 时 shape 为 (L,N,Eq)(L, N, E_q) 或当 batch_first=True 时 shape 为 (N,L,Eq)(N, L, E_q),其中 LL 是目标序列长度,NN 是批次大小,EqE_q 是 query 嵌入维度 embed_dim。Query 与 key-value 对进行比较以生成输出。更多细节请参见“Attention Is All You Need”。

  • key (Tensor) – shape 为 (S,Ek)(S, E_k) 的未批处理输入,当 batch_first=False 时 shape 为 (S,N,Ek)(S, N, E_k) 或当 batch_first=True 时 shape 为 (N,S,Ek)(N, S, E_k),其中 SS 是源序列长度,NN 是批次大小,EkE_k 是 key 嵌入维度 kdim。更多细节请参见“Attention Is All You Need”。

  • value (Tensor) – shape 为 (S,Ev)(S, E_v) 的未批处理输入,当 batch_first=False 时 shape 为 (S,N,Ev)(S, N, E_v) 或当 batch_first=True 时 shape 为 (N,S,Ev)(N, S, E_v),其中 SS 是源序列长度,NN 是批次大小,EvE_v 是 value 嵌入维度 vdim。更多细节请参见“Attention Is All You Need”。

  • key_padding_mask (Optional[Tensor]) – 如果指定,则为 shape 为 (N,S)(N, S) 的掩码,指示 key 中哪些元素应被忽略(即被视为“padding”)。对于未批处理的 query,shape 应为 (S)(S)。支持二进制和浮点掩码。对于二进制掩码,True 值表示相应的 key 值将被忽略。对于浮点掩码,它将直接添加到相应的 key 值中。

  • need_weights (bool) – 如果指定,除了 attn_outputs 外,还将返回 attn_output_weights。将 need_weights=False 以使用优化的 scaled_dot_product_attention 并为 MHA 获得最佳性能。默认为 True

  • attn_mask (Optional[Tensor]) – 如果指定,则为 2D 或 3D 掩码,可防止注意力集中到某些位置。必须是 shape (L,S)(L, S)(Nnum_heads,L,S)(N\cdot\text{num\_heads}, L, S),其中 NN 是批次大小,LL 是目标序列长度,SS 是源序列长度。2D 掩码将广播到整个批次,而 3D 掩码允许批次中的每个条目都有不同的掩码。支持二进制和浮点掩码。对于二进制掩码,True 值表示相应的位置不允许注意力。对于浮点掩码,掩码值将加到注意力权重上。如果同时提供了 attn_mask 和 key_padding_mask,它们的类型应匹配。

  • average_attn_weights (bool) – 如果为 True,则表示返回的 attn_weights 应在 heads 之间取平均值。否则,attn_weights 将按 head 分别提供。请注意,此标志仅在 need_weights=True 时生效。默认为 True(即在 heads 之间平均权重)。

  • is_causal (bool) – 如果指定,则将因果掩码作为注意力掩码应用。默认为 False。警告:is_causal 提供了一个 attn_mask 是因果掩码的提示。提供不正确的提示可能导致执行错误,包括向前和向后兼容性。

返回类型

tuple[torch.Tensor, Optional[torch.Tensor]]

输出
  • attn_output - 注意力输出,shape 为 (L,E)(L, E)(当输入未批处理时),(L,N,E)(L, N, E)(当 batch_first=False 时)或 (N,L,E)(N, L, E)(当 batch_first=True 时),其中 LL 是目标序列长度,NN 是批次大小,EE 是嵌入维度 embed_dim

  • attn_output_weights – 仅当 need_weights=True 时返回。如果 average_attn_weights=True,则返回在 heads 之间平均后的注意力权重,shape 为 (L,S)(L, S)(当输入未批处理时)或 (N,L,S)(N, L, S)(当 batch_first=False 时),其中 NN 是批次大小,LL 是目标序列长度,SS 是源序列长度。如果 average_attn_weights=False,则返回按 head 分布的注意力权重,shape 为 (num_heads,L,S)(\text{num\_heads}, L, S)(当输入未批处理时)或 (N,num_heads,L,S)(N, \text{num\_heads}, L, S)(当输入批处理时)。

注意

batch_first 参数对于未批处理的输入将被忽略。

merge_masks(attn_mask, key_padding_mask, query)[源代码]#

确定掩码类型并合并掩码(如果需要)。

如果只提供一个掩码,将返回该掩码和相应的掩码类型。如果同时提供两个掩码,它们都将被扩展到 shape (batch_size, num_heads, seq_len, seq_len),并使用逻辑 or 合并,并将返回掩码类型 2 :param attn_mask: shape 为 (seq_len, seq_len) 的注意力掩码,掩码类型 0 :param key_padding_mask: shape 为 (batch_size, seq_len) 的 padding 掩码,掩码类型 1 :param query: shape 为 (batch_size, seq_len, embed_dim) 的 query 嵌入

返回

merged mask mask_type: 合并后的掩码类型(0、1 或 2)

返回类型

merged_mask