评价此页

TransformerEncoder#

class torch.nn.modules.transformer.TransformerEncoder(encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True)[源码]#

TransformerEncoder 是 N 个编码器层的堆栈。

此 TransformerEncoder 层实现了 Attention Is All You Need 论文中描述的原始架构。此层的目的是作为基础理解的参考实现,因此相比于较新的 Transformer 架构,它只包含有限的功能。鉴于 Transformer 类架构的快速创新步伐,我们建议探索此 教程,使用核心模块中的构建块来构建高效的层,或使用 PyTorch 生态系统 中的更高级库。

警告

TransformerEncoder 中的所有层都使用相同的参数进行初始化。建议在创建 TransformerEncoder 实例后手动初始化层。

参数
  • encoder_layer (TransformerEncoderLayer) – TransformerEncoderLayer() 类的实例(必需)。

  • num_layers (int) – 编码器中子编码器层的数量(必需)。

  • norm (Optional[Module]) – 层归一化组件(可选)。

  • enable_nested_tensor (bool) – 如果为 True,输入将自动转换为嵌套张量(并在输出时转换回来)。当填充率很高时,这将提高 TransformerEncoder 的整体性能。默认为 True(启用)。

示例

>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
forward(src, mask=None, src_key_padding_mask=None, is_causal=None)[源码]#

依次将输入通过编码器层。

参数
  • src (Tensor) – 到编码器的序列(必需)。

  • mask (Optional[Tensor]) – src 序列的掩码(可选)。

  • src_key_padding_mask (Optional[Tensor]) – src 键的每批掩码(可选)。

  • is_causal (Optional[bool]) – 如果指定,将 mask 应用为因果掩码。默认为 None;尝试检测因果掩码。警告:is_causal 提供了一个提示,表明 mask 是因果掩码。提供错误的提示可能导致执行错误,包括向前和向后兼容性。

返回类型

张量

形状

请参阅Transformer中的文档。