评价此页

TransformerDecoder#

class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[source]#

TransformerDecoder 是 N 个解码器层的堆栈。

这个 TransformerDecoder 层实现了 Attention Is All You Need 论文中描述的原始架构。此层的目的是作为基础理解的参考实现,因此与较新的 Transformer 架构相比,它只包含有限的功能。鉴于 Transformer 类架构的快速创新,我们建议探索此 教程,从核心构建块中构建高效层,或使用 PyTorch 生态系统 的高级库。

警告

TransformerDecoder 中的所有层都使用相同的参数进行初始化。建议在创建 TransformerDecoder 实例后手动初始化层。

参数
  • decoder_layer (TransformerDecoderLayer) – TransformerDecoderLayer() 类的实例(必需)。

  • num_layers (int) – 解码器中子解码器层的数量(必需)。

  • norm (Optional[Module]) – 层归一化组件(可选)。

示例

>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[source]#

依次通过解码器层传递输入(和掩码)。

参数
  • tgt (Tensor) – 解码器的序列(必需)。

  • memory (Tensor) – 来自编码器最后一层的序列(必需)。

  • tgt_mask (Optional[Tensor]) – tgt 序列的掩码(可选)。

  • memory_mask (Optional[Tensor]) – memory 序列的掩码(可选)。

  • tgt_key_padding_mask (Optional[Tensor]) – tgt 键的每批掩码(可选)。

  • memory_key_padding_mask (Optional[Tensor]) – memory 键的每批掩码(可选)。

  • tgt_is_causal (Optional[bool]) – 如果指定,则将因果掩码应用于 tgt mask。默认为 None;尝试检测因果掩码。警告:tgt_is_causal 提供了一个提示,即 tgt_mask 是因果掩码。提供不正确的提示可能导致执行错误,包括向前和向后兼容性问题。

  • memory_is_causal (bool) – 如果指定,则将因果掩码应用于 memory mask。默认为 False。警告:memory_is_causal 提供了一个提示,即 memory_mask 是因果掩码。提供不正确的提示可能导致执行错误,包括向前和向后兼容性问题。

返回类型

张量

形状

请参阅Transformer中的文档。