MultiheadAttention#
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源]#
允许模型联合关注来自不同表示子空间的信息。
此 MultiheadAttention 层实现了 Attention Is All You Need 论文中描述的原始架构。该层旨在作为基础理解的参考实现,因此其功能相对于较新的架构而言仅限于有限的功能。鉴于 Transformer 类架构的快速创新步伐,我们建议您探索此 教程,以从核心构建块构建高效层,或使用 PyTorch 生态系统中的更高级库。
多头注意力 (Multi-Head Attention) 定义为:
其中 。
nn.MultiheadAttention将在可能的情况下使用scaled_dot_product_attention()的优化实现。除了支持新的
scaled_dot_product_attention()函数外,为了加速推理,MHA 还将使用支持 Nested Tensors 的 fastpath 推理,前提是:计算自注意力(即
query、key和value是同一个张量)。输入是批处理的(3D),并且
batch_first==True。自动梯度被禁用(使用
torch.inference_mode或torch.no_grad)或者没有张量参数requires_grad训练被禁用(使用
.eval())add_bias_kv为False。add_zero_attn为False。kdim和vdim等于embed_dim。如果传递了 NestedTensor,则不传递
key_padding_mask和attn_mask。autocast 已禁用。
如果使用了优化的推理 fastpath 实现,则可以为
query/key/value传递 NestedTensor,以比使用 padding mask 更有效地表示 padding。在这种情况下,将返回 NestedTensor,并且可以预期额外的加速与输入中 padding 的比例成正比。- 参数
embed_dim – 模型的总维度。
num_heads – 并行注意力头的数量。请注意,
embed_dim将被分割到num_heads中(即每个头的维度为embed_dim // num_heads)。dropout –
attn_output_weights上的 Dropout 概率。默认值:0.0(无 dropout)。bias – 如果指定,则向输入/输出投影层添加偏置。默认值:
True。add_bias_kv – 如果指定,则向 key 和 value 序列在 dim=0 处添加偏置。默认值:
False。add_zero_attn – 如果指定,则向 key 和 value 序列在 dim=1 处添加新的零批次。默认值:
False。kdim – key 的总特征数。默认值:
None(使用kdim=embed_dim)。vdim – value 的总特征数。默认值:
None(使用vdim=embed_dim)。batch_first – 如果为
True,则输入和输出张量为 (batch, seq, feature)。默认值:False(seq, batch, feature)。
示例
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源]#
使用 query、key 和 value 嵌入计算注意力输出。
支持用于 padding、mask 和注意力权重的可选参数。
- 参数
query (Tensor) – shape 为 的未批处理输入, 当
batch_first=False时,或 当batch_first=True时,其中 是目标序列长度, 是批次大小, 是查询嵌入维度embed_dim。查询与键值对进行比较以生成输出。更多细节请参阅“Attention Is All You Need”。key (Tensor) – shape 为 的未批处理输入, 当
batch_first=False时,或 当batch_first=True时,其中 是源序列长度, 是批次大小, 是键嵌入维度kdim。更多细节请参阅“Attention Is All You Need”。value (Tensor) – shape 为 的未批处理输入, 当
batch_first=False时,或 当batch_first=True时,其中 是源序列长度, 是批次大小, 是值嵌入维度vdim。更多细节请参阅“Attention Is All You Need”。key_padding_mask (Optional[Tensor]) – 如果指定,则为 shape 为 的掩码,指示
key中哪些元素在注意力计算中被忽略(即被视为“padding”)。对于未批处理的 query,shape 应为 。支持二进制和浮点掩码。对于二进制掩码,True值表示相应的key值在注意力计算中将被忽略。对于浮点掩码,它将直接加到相应的key值上。need_weights (bool) – 如果指定,则除了
attn_outputs之外,还返回attn_output_weights。将need_weights=False以使用优化的scaled_dot_product_attention并为 MHA 获得最佳性能。默认值:True。attn_mask (Optional[Tensor]) – 如果指定,则为 2D 或 3D 掩码,可防止注意力关注某些位置。必须为 shape 或 ,其中 是批次大小, 是目标序列长度, 是源序列长度。2D 掩码将广播到整个批次,而 3D 掩码则允许每个批次条目都有不同的掩码。支持二进制和浮点掩码。对于二进制掩码,
True值表示不允许关注相应的位置。对于浮点掩码,掩码值将加到注意力权重上。如果同时提供了 attn_mask 和 key_padding_mask,它们的类型应匹配。average_attn_weights (bool) – 如果为 true,表示返回的
attn_weights应在各头之间取平均值。否则,attn_weights按头分开提供。请注意,此标志仅在need_weights=True时生效。默认值:True(即平均各头权重)。is_causal (bool) – 如果指定,则将因果掩码作为注意力掩码应用。默认值:
False。警告:is_causal提供了一个提示,即attn_mask是因果掩码。提供错误的提示可能导致执行错误,包括向前和向后兼容性问题。
- 返回类型
- 输出
attn_output - shape 为 的注意力输出,当输入未批处理时; 当
batch_first=False时;或 当batch_first=True时,其中 是目标序列长度, 是批次大小, 是嵌入维度embed_dim。attn_output_weights - 仅当
need_weights=True时返回。如果average_attn_weights=True,则返回平均后的注意力权重,shape 为 ,当输入未批处理时,或 ,当batch_first=False时,其中 是批次大小, 是目标序列长度, 是源序列长度。如果average_attn_weights=False,则返回各头的注意力权重,shape 为 ,当输入未批处理时,或 ,当batch_first=False时。
注意
batch_first 参数对于未批处理的输入将被忽略。
- merge_masks(attn_mask, key_padding_mask, query)[源]#
确定掩码类型并根据需要合并掩码。
如果只提供一个掩码,则返回该掩码和对应的掩码类型。如果同时提供两个掩码,它们将被扩展到 shape
(batch_size, num_heads, seq_len, seq_len),并通过逻辑or合并,并返回掩码类型 2::param attn_mask: attention mask,shape 为(seq_len, seq_len),掩码类型 0 :param key_padding_mask: padding mask,shape 为(batch_size, seq_len),掩码类型 1 :param query: query embeddings,shape 为(batch_size, seq_len, embed_dim)- 返回
merged mask mask_type: 合并后的掩码类型(0、1 或 2)。
- 返回类型
merged_mask