ASR Inference with CTC Decoder

Author: Caroline Chen

This tutorial shows how to perform speech recognition inference using a CTC beam search decoder with lexicon constraint and KenLM language model support. We demonstrate this on a pretrained wav2vec 2.0 model trained using CTC loss.

Overview

Beam search decoding works by iteratively expanding text hypotheses (beams) with next possible characters, and maintaining only the hypotheses with the highest scores at each time step. A language model can be incorporated into the score computation, and adding a lexicon constraint restricts the next possible tokens for the hypotheses so that only words from the lexicon can be generated.

The underlying implementation is ported from Flashlight's beam search decoder. A mathematical formula for the decoder optimization can be found in the Wav2Letter paper, and a more detailed algorithm can be found in this blog.
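As a rough sketch of the idea (not the library's implementation), a character-level beam search without CTC blank handling, lexicon constraint, or LM scoring could look like the following, where log_probs is a [num_frames][num_tokens] grid of per-frame log-probabilities:

def toy_beam_search(log_probs, tokens, beam_size):
    """Illustrative beam search: extend every hypothesis with every token,
    then keep only the `beam_size` highest-scoring hypotheses per frame."""
    beams = [("", 0.0)]  # (hypothesis text, cumulative log-score)
    for frame in log_probs:
        candidates = [
            (text + token, score + frame[i])
            for text, score in beams
            for i, token in enumerate(tokens)
        ]
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return beams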
Running ASR inference using a CTC beam search decoder with a language model and lexicon constraint requires the following components

Acoustic Model: model predicting phonetics from audio waveforms
Tokens: the possible predicted tokens from the acoustic model
Lexicon: mapping between possible words and their corresponding token sequences
Language Model (LM): n-gram language model trained with the KenLM library, or custom language model that inherits CTCDecoderLM
Acoustic Model and Set Up

First we import the necessary utilities and fetch the data that we are working with
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
2.8.0+cu126
2.8.0
import time
from typing import List
import IPython
import matplotlib.pyplot as plt
from torchaudio.models.decoder import ctc_decoder
from torchaudio.utils import download_asset
We use the pretrained Wav2Vec 2.0 Base model that is fine-tuned on 10 min of the LibriSpeech dataset, which can be loaded in using torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M. For more detail on running Wav2Vec 2.0 speech recognition pipelines in torchaudio, please refer to this tutorial.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
acoustic_model = bundle.get_model()
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ll10m.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ll10m.pth
100%|##########| 360M/360M [00:00<00:00, 469MB/s]
We will load a sample from the LibriSpeech test-other dataset.
speech_file = download_asset("tutorial-assets/ctc-decoding/1688-142285-0007.wav")
IPython.display.Audio(speech_file)
/pytorch/audio/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py:88: UserWarning: torchaudio.utils.download.download_asset has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
speech_file = download_asset("tutorial-assets/ctc-decoding/1688-142285-0007.wav")
The transcript corresponding to this audio file is

i really was very much afraid of showing him how much shocked i was at some parts of what he said
waveform, sample_rate = torchaudio.load(speech_file)

if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
/pytorch/audio/src/torchaudio/_backend/utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.ac.cn/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.
warnings.warn(
/pytorch/audio/src/torchaudio/_backend/ffmpeg.py:88: UserWarning: torio.io._streaming_media_decoder.StreamingMediaDecoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. The decoding and encoding capabilities of PyTorch for both audio and video are being consolidated into TorchCodec. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
s = torchaudio.io.StreamReader(src, format, None, buffer_size)
Files and Data for Decoder

Next, we load in our token, lexicon, and language model data, which are used by the decoder to predict words from the acoustic model output. Pretrained files for the LibriSpeech dataset can be downloaded through torchaudio, or the user can provide their own files.

Tokens

The tokens are the possible symbols that the acoustic model can predict, including the blank and silent symbols. They can either be passed in as a file, where each line consists of the token corresponding to the same index, or as a list of tokens, each mapping to a unique index.
# tokens.txt
_
|
e
t
...
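The token list printed below can also be read directly off the pipeline bundle; a minimal sketch (bundle.get_labels() returns the acoustic model's label set):

tokens = [label.lower() for label in bundle.get_labels()]
print(tokens)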
['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']
Lexicon

The lexicon is a mapping from words to their corresponding token sequences, and is used to restrict the search space of the decoder to only words from the lexicon. The expected format of the lexicon file is a line per word, with a word followed by its space-separated tokens.
# lexicon.txt
a a |
able a b l e |
about a b o u t |
...
...
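As a quick sketch of how this format can be parsed (the file name below is a placeholder for any lexicon file in the format above):

# "lexicon.txt" is a placeholder path for a file in the format shown above.
with open("lexicon.txt") as f:
    for line in f:
        word, *spelling = line.split()
        print(word, "->", " ".join(spelling))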
Language Model

A language model can be used in decoding to improve the results, by factoring in a language model score that represents the likelihood of the sequence into the beam search computation. Below, we outline the different forms of language models that are supported for decoding.

No Language Model

To create a decoder instance without a language model, set lm=None when initializing the decoder.
KenLM

This is an n-gram language model trained with the KenLM library. Both the .arpa or the binarized .bin LM can be used, but the binary format is recommended for faster loading.

The language model used in this tutorial is a 4-gram KenLM trained using LibriSpeech.
Custom Language Model

Users can define their own custom language model in Python, whether it be a statistical or neural network language model, using CTCDecoderLM and CTCDecoderLMState.

For instance, the following code creates a basic wrapper around a PyTorch torch.nn.Module language model.
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState
class CustomLM(CTCDecoderLM):
    """Create a Python wrapper around `language_model` to feed to the decoder."""

    def __init__(self, language_model: torch.nn.Module):
        CTCDecoderLM.__init__(self)
        self.language_model = language_model
        self.sil = -1  # index for silent token in the language model
        self.states = {}

        language_model.eval()

    def start(self, start_with_nothing: bool = False):
        state = CTCDecoderLMState()
        with torch.no_grad():
            score = self.language_model(self.sil)

        self.states[state] = score
        return state

    def score(self, state: CTCDecoderLMState, token_index: int):
        outstate = state.child(token_index)
        if outstate not in self.states:
            score = self.language_model(token_index)
            self.states[outstate] = score
        score = self.states[outstate]

        return outstate, score

    def finish(self, state: CTCDecoderLMState):
        return self.score(state, self.sil)
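A wrapper like this can then be passed to the decoder factory in place of a KenLM file, since ctc_decoder() also accepts a CTCDecoderLM instance as lm. The sketch below is illustrative: UniformLM is a hypothetical stand-in for a real language model, and files refers to the pretrained files downloaded in the next step:

class UniformLM(torch.nn.Module):
    """Hypothetical toy LM that assigns the same score to every token."""

    def forward(self, token_index: int) -> float:
        return 0.0


custom_lm_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=CustomLM(UniformLM()),  # pass the wrapped module instead of a KenLM path
)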
Downloading Pretrained Files

Pretrained files for the LibriSpeech dataset can be downloaded using download_pretrained_files().

Note: this cell may take a couple of minutes to run, as the language model can be large
from torchaudio.models.decoder import download_pretrained_files
files = download_pretrained_files("librispeech-4-gram")
print(files)
/pytorch/audio/src/torchaudio/models/decoder/_ctc_decoder.py:557: UserWarning: torchaudio.utils.download.download_asset has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
lexicon_file = download_asset(files.lexicon)
100%|##########| 4.97M/4.97M [00:00<00:00, 316MB/s]
/pytorch/audio/src/torchaudio/models/decoder/_ctc_decoder.py:558: UserWarning: torchaudio.utils.download.download_asset has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
tokens_file = download_asset(files.tokens)
100%|##########| 57.0/57.0 [00:00<00:00, 115kB/s]
/pytorch/audio/src/torchaudio/models/decoder/_ctc_decoder.py:560: UserWarning: torchaudio.utils.download.download_asset has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
lm_file = download_asset(files.lm)
100%|##########| 2.91G/2.91G [00:07<00:00, 416MB/s]
PretrainedFiles(lexicon='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt', tokens='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/tokens.txt', lm='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lm.bin')
Construct Decoders

In this tutorial, we construct a beam search decoder and a greedy decoder for comparison.

Beam Search Decoder

The decoder can be constructed using the factory function ctc_decoder(). In addition to the previously mentioned components, it also takes in various beam search decoding parameters and token/word parameters.

This decoder can also be run without a language model by passing in None into the lm parameter.
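For example, a lexicon-constrained decoder without a language model could be constructed as follows (a minimal sketch using the pretrained files downloaded above):

# Beam search decoder with a lexicon constraint but no language model.
beam_search_decoder_no_lm = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=None,
)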
LM_WEIGHT = 3.23
WORD_SCORE = -0.26
beam_search_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=3,
    beam_size=1500,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
)
Greedy Decoder
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> List[str]:
        """Given a sequence emission over labels, get the best path
        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            List[str]: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        return joined.replace("|", " ").strip().split()
greedy_decoder = GreedyCTCDecoder(tokens)
Run Inference

Now that we have the data, acoustic model, and decoder, we can perform inference. The output of the beam search decoder is of type CTCHypothesis, consisting of the predicted token IDs, corresponding words (if a lexicon is provided), hypothesis score, and timesteps corresponding to the token IDs. Recall the transcript corresponding to the waveform is
actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
actual_transcript = actual_transcript.split()
emission, _ = acoustic_model(waveform)
The greedy decoder gives the following result.
greedy_result = greedy_decoder(emission[0])
greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)
print(f"Transcript: {greedy_transcript}")
print(f"WER: {greedy_wer}")
Transcript: i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid
WER: 0.38095238095238093
Using the beam search decoder:
beam_search_result = beam_search_decoder(emission)
beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
    actual_transcript
)
print(f"Transcript: {beam_search_transcript}")
print(f"WER: {beam_search_wer}")
Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
WER: 0.047619047619047616
Note

If no lexicon is provided to the decoder, the output hypotheses' words field will be empty. To retrieve a transcript with lexicon-free decoding, you can perform the following to retrieve the token indices, convert them to original tokens, then join them together.
tokens_str = "".join(beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens))
transcript = " ".join(tokens_str.split("|"))
We see that the transcript with the lexicon-constrained beam search decoder produces a more accurate result consisting of real words, while the greedy decoder can predict incorrectly spelled words like "affrayd" and "shoktd".

Incremental decoding

If the input speech is long, you can decode emissions incrementally.

You need to first initialize the internal state of the decoder with decode_begin().
beam_search_decoder.decode_begin()
Then, you can pass emissions to decode_step(). Here, we use the same emission, but pass it to the decoder one frame at a time, as shown below.
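A minimal sketch of the frame-by-frame loop (emission has shape [batch, num_frames, num_tokens]):

for t in range(emission.size(1)):
    beam_search_decoder.decode_step(emission[0, t, :])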
Finally, finalize the internal state of the decoder, and retrieve the result.
beam_search_decoder.decode_end()
beam_search_result_inc = beam_search_decoder.get_final_hypothesis()
The result of the incremental decoding is identical to that of the batch decoding.
beam_search_transcript_inc = " ".join(beam_search_result_inc[0].words).strip()
beam_search_wer_inc = torchaudio.functional.edit_distance(
    actual_transcript, beam_search_result_inc[0].words) / len(actual_transcript)
print(f"Transcript: {beam_search_transcript_inc}")
print(f"WER: {beam_search_wer_inc}")
assert beam_search_result[0][0].words == beam_search_result_inc[0].words
assert beam_search_result[0][0].score == beam_search_result_inc[0].score
torch.testing.assert_close(beam_search_result[0][0].timesteps, beam_search_result_inc[0].timesteps)
Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
WER: 0.047619047619047616
Timestep Alignments

Recall that one of the components of the resulting hypotheses is timesteps corresponding to the token IDs.
timesteps = beam_search_result[0][0].timesteps
predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)
print(predicted_tokens, len(predicted_tokens))
print(timesteps, timesteps.shape[0])
['|', 'i', '|', 'r', 'e', 'a', 'l', 'l', 'y', '|', 'w', 'a', 's', '|', 'v', 'e', 'r', 'y', '|', 'm', 'u', 'c', 'h', '|', 'a', 'f', 'r', 'a', 'i', 'd', '|', 'o', 'f', '|', 's', 'h', 'o', 'w', 'i', 'n', 'g', '|', 'h', 'i', 'm', '|', 'h', 'o', 'w', '|', 'm', 'u', 'c', 'h', '|', 's', 'h', 'o', 'c', 'k', 'e', 'd', '|', 'i', '|', 'w', 'a', 's', '|', 'a', 't', '|', 's', 'o', 'm', 'e', '|', 'p', 'a', 'r', 't', '|', 'o', 'f', '|', 'w', 'h', 'a', 't', '|', 'h', 'e', '|', 's', 'a', 'i', 'd', '|', '|'] 99
tensor([ 0, 31, 33, 36, 39, 41, 42, 44, 46, 48, 49, 52, 54, 58,
64, 66, 69, 73, 74, 76, 80, 82, 84, 86, 88, 94, 97, 107,
111, 112, 116, 134, 136, 138, 140, 142, 146, 148, 151, 153, 155, 157,
159, 161, 162, 166, 170, 176, 177, 178, 179, 182, 184, 186, 187, 191,
193, 198, 201, 202, 203, 205, 207, 212, 213, 216, 222, 224, 230, 250,
251, 254, 256, 261, 262, 264, 267, 270, 276, 277, 281, 284, 288, 289,
292, 295, 297, 299, 300, 303, 305, 307, 310, 311, 324, 325, 329, 331,
353], dtype=torch.int32) 99
Below, we visualize the token timestep alignments relative to the original waveform.
def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):

    t = torch.arange(waveform.size(0)) / sample_rate
    ratio = waveform.size(0) / emission.size(1) / sample_rate

    chars = []
    words = []
    word_start = None
    for token, timestep in zip(tokens, timesteps * ratio):
        if token == "|":
            if word_start is not None:
                words.append((word_start, timestep))
            word_start = None
        else:
            chars.append((token, timestep))
            if word_start is None:
                word_start = timestep

    fig, axes = plt.subplots(3, 1)

    def _plot(ax, xlim):
        ax.plot(t, waveform)
        for token, timestep in chars:
            ax.annotate(token.upper(), (timestep, 0.5))
        for word_start, word_end in words:
            ax.axvspan(word_start, word_end, alpha=0.1, color="red")
        ax.set_ylim(-0.6, 0.7)
        ax.set_yticks([0])
        ax.grid(True, axis="y")
        ax.set_xlim(xlim)

    _plot(axes[0], (0.3, 2.5))
    _plot(axes[1], (2.5, 4.7))
    _plot(axes[2], (4.7, 6.9))
    axes[2].set_xlabel("time (sec)")
    fig.tight_layout()
plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)

Beam Search Decoder Parameters

In this section, we go a little bit more in depth about some different parameters and tradeoffs. For the full list of customizable parameters, please refer to the documentation.

Helper Function
def print_decoded(decoder, emission, param, param_value):
    start_time = time.monotonic()
    result = decoder(emission)
    decode_time = time.monotonic() - start_time

    transcript = " ".join(result[0][0].words).lower().strip()
    score = result[0][0].score
    print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")
nbest

This parameter indicates the number of best hypotheses to return, which is a property that is not possible with the greedy decoder. For instance, by setting nbest=3 when constructing the beam search decoder earlier, we can now access the hypotheses with the top 3 scores.
for i in range(3):
    transcript = " ".join(beam_search_result[0][i].words).strip()
    score = beam_search_result[0][i].score
    print(f"{transcript} (score: {score})")
i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.824109642502)
i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: 3697.858373688456)
i reply was very much afraid of showing him how much shocked i was at some part of what he said (score: 3695.0157600045172)
beam size

The beam_size parameter determines the maximum number of best hypotheses to hold after each decoding step. Using larger beam sizes allows for exploring a larger range of possible hypotheses, which can produce hypotheses with higher scores, but it is computationally more expensive and does not provide additional gains beyond a certain point.

In the example below, we see improvement in decoding quality as we increase beam size from 1 to 5 to 50, but notice how using a beam size of 500 provides the same output as beam size 50 while increasing the computation time.
beam_sizes = [1, 5, 50, 500]
for beam_size in beam_sizes:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size=beam_size,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam size", beam_size)
beam size 1 : i you ery much afra of shongut shot i was at some arte what he sad (score: 3144.93; 0.0233 secs)
beam size 5 : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3688.02; 0.0480 secs)
beam size 50 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.1666 secs)
beam size 500: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.5521 secs)
beam size token

The beam_size_token parameter corresponds to the number of tokens to consider for expanding each hypothesis at the decoding step. Exploring a larger number of next possible tokens increases the range of potential hypotheses at the cost of computation.
num_tokens = len(tokens)
beam_size_tokens = [1, 5, 10, num_tokens]
for beam_size_token in beam_size_tokens:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size_token=beam_size_token,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)
beam size token 1 : i rely was very much affray of showing him hoch shot i was at some part of what he sed (score: 3584.80; 0.1605 secs)
beam size token 5 : i rely was very much afraid of showing him how much shocked i was at some part of what he said (score: 3694.83; 0.1809 secs)
beam size token 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3696.25; 0.2005 secs)
beam size token 29 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2326 secs)
beam threshold

The beam_threshold parameter is used to prune the stored hypotheses set at each decoding step, removing hypotheses whose scores are greater than beam_threshold away from the highest scoring hypothesis. There is a balance between choosing smaller thresholds to prune more hypotheses and reduce the search space, and choosing a large enough threshold such that plausible hypotheses are not pruned.
beam_thresholds = [1, 5, 10, 25]
for beam_threshold in beam_thresholds:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_threshold=beam_threshold,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)
beam threshold 1 : i ila ery much afraid of shongut shot i was at some parts of what he said (score: 3316.20; 0.0287 secs)
beam threshold 5 : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3682.23; 0.0508 secs)
beam threshold 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2170 secs)
beam threshold 25 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2371 secs)
language model weight

The lm_weight parameter is the weight to assign to the language model score, which is accumulated with the acoustic model score to determine the overall score. Larger weights encourage the model to predict next words based on the language model, while smaller weights give more weight to the acoustic model score instead.
lm_weights = [0, LM_WEIGHT, 15]
for lm_weight in lm_weights:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        lm_weight=lm_weight,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)
lm weight 0 : i rely was very much affraid of showing him ho much shoke i was at some parte of what he seid (score: 3834.05; 0.2575 secs)
lm weight 3.23: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2619 secs)
lm weight 15 : was there in his was at some of what he said (score: 2918.99; 0.2444 secs)
additional parameters

Additional parameters that can be optimized include the following; a hedged construction example follows the list.

word_score: score to add when a word finishes
unk_score: score to add for unknown word appearances
sil_score: score to add for silence appearances
log_add: whether to use log add for lexicon Trie smearing
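As an illustration only, all of these can be passed to the same ctc_decoder() factory; the values below are arbitrary placeholders rather than tuned settings:

beam_search_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,    # score added when a word finishes
    unk_score=float("-inf"),  # placeholder: effectively disallow unknown words
    sil_score=0.0,            # placeholder: no extra score for silence
    log_add=False,            # placeholder: use max instead of log-add for trie smearing
)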
Total running time of the script: (1 minutes 59.323 seconds)