快捷方式

torchaudio.models.emformer_rnnt_model

torchaudio.models.emformer_rnnt_model(*, input_dim: int, encoding_dim: int, num_symbols: int, segment_length: int, right_context_length: int, time_reduction_input_dim: int, time_reduction_stride: int, transformer_num_heads: int, transformer_ffn_dim: int, transformer_num_layers: int, transformer_dropout: float, transformer_activation: str, transformer_left_context_length: int, transformer_max_memory_size: int, transformer_weight_init_scale_strategy: str, transformer_tanh_on_mem: bool, symbol_embedding_dim: int, num_lstm_layers: int, lstm_layer_norm: bool, lstm_layer_norm_epsilon: float, lstm_dropout: float) RNNT[源代码]

构建基于 Emformer 的 RNNT

注意

对于非流式推理,预期调用 transcribe 时输入的序列会与 right_context_length 帧进行右连接。

对于流式推理,预期调用 transcribe_streaming 时输入的块包含 segment_length 帧,并与 right_context_length 帧进行右连接。

参数
  • input_dim (int) – 传递给转录网络的输入序列帧的维度。

  • encoding_dim (int) – 传递给联合网络的转录网络和预测网络生成的编码的维度。

  • num_symbols (int) – 目标 token 集合的基数。

  • segment_length (int) – 以帧数为单位的输入片段的长度。

  • right_context_length (int) – 以帧数为单位的右侧上下文的长度。

  • time_reduction_input_dim (int) – 在应用时间缩减块之前,每个输入序列元素缩放到的维度。

  • time_reduction_stride (int) – 输入序列长度缩减的因子。

  • transformer_num_heads (int) – 每个 Emformer 层中的注意力头数。

  • transformer_ffn_dim (int) – 每个 Emformer 层的全连接网络的隐藏层维度。

  • transformer_num_layers (int) – 要实例化的 Emformer 层数。

  • transformer_left_context_length (int) – Emformer 考虑的左侧上下文的长度。

  • transformer_dropout (float) – Emformer 的 dropout 概率。

  • transformer_activation (str) – 在每个 Emformer 层的全连接网络中使用的激活函数。必须是 (“relu”, “gelu”, “silu”) 之一。

  • transformer_max_memory_size (int) – 要使用的最大内存元素数量。

  • transformer_weight_init_scale_strategy (str) – 按层权重初始化缩放策略。必须是 (“depthwise”, “constant”, None) 之一。

  • transformer_tanh_on_mem (bool) – 如果为 True,则将 tanh 应用于内存元素。

  • symbol_embedding_dim (int) – 每个目标 token 嵌入的维度。

  • num_lstm_layers (int) – 要实例化的 LSTM 层数。

  • lstm_layer_norm (bool) – 如果为 True,则为 LSTM 层启用层归一化。

  • lstm_layer_norm_epsilon (float) – 在 LSTM 层归一化层中使用的 epsilon 值。

  • lstm_dropout (float) – LSTM 的 dropout 概率。

返回

Emformer RNN-T 模型。

返回类型

RNNT

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源