注意
点击此处下载完整示例代码
使用 Wav2Vec2 进行语音识别¶
作者:Moto Hira
本教程展示了如何使用 Wav2Vec2.0 [论文] 的预训练模型执行语音识别。
概述¶
语音识别过程如下所示。
从音频波形中提取声学特征
逐帧估计声学特征的类别
从类别概率序列生成假设
Torchaudio 提供了对预训练权重和相关信息(例如预期采样率和类别标签)的轻松访问。它们捆绑在一起,可在 torchaudio.pipelines
模块下获得。
准备¶
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
2.8.0+cu126
2.8.0
cuda
import IPython
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
SPEECH_FILE = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
/pytorch/audio/examples/tutorials/speech_recognition_pipeline_tutorial.py:56: UserWarning: torchaudio.utils.download.download_asset has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
SPEECH_FILE = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
0%| | 0.00/106k [00:00<?, ?B/s]
100%|##########| 106k/106k [00:00<00:00, 37.4MB/s]
创建流水线¶
首先,我们将创建一个 Wav2Vec2 模型,该模型执行特征提取和分类。
torchaudio 中提供两种类型的 Wav2Vec2 预训练权重。一种是针对 ASR 任务微调的,另一种是未微调的。
Wav2Vec2(和 HuBERT)模型以自监督方式进行训练。它们首先仅使用音频进行表示学习,然后使用额外的标签针对特定任务进行微调。
未微调的预训练权重也可以针对其他下游任务进行微调,但本教程不涉及此内容。
我们将在此处使用 torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
。
torchaudio.pipelines
中提供了多个预训练模型。请查看文档以获取它们的训练详情。
捆绑包对象提供了实例化模型和其他信息的接口。采样率和类别标签如下所示。
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
print("Sample Rate:", bundle.sample_rate)
print("Labels:", bundle.get_labels())
Sample Rate: 16000
Labels: ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
模型可以按如下方式构建。此过程将自动获取预训练权重并将其加载到模型中。
model = bundle.get_model().to(device)
print(model.__class__)
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth
0%| | 0.00/360M [00:00<?, ?B/s]
14%|#3 | 48.6M/360M [00:00<00:00, 507MB/s]
27%|##6 | 97.0M/360M [00:00<00:00, 469MB/s]
40%|#### | 145M/360M [00:00<00:00, 483MB/s]
53%|#####3 | 191M/360M [00:00<00:00, 419MB/s]
64%|######4 | 232M/360M [00:00<00:00, 398MB/s]
75%|#######5 | 271M/360M [00:00<00:00, 390MB/s]
86%|########6 | 310M/360M [00:00<00:00, 397MB/s]
97%|#########6| 349M/360M [00:00<00:00, 379MB/s]
100%|##########| 360M/360M [00:00<00:00, 407MB/s]
<class 'torchaudio.models.wav2vec2.model.Wav2Vec2Model'>
加载数据¶
我们将使用来自 VOiCES 数据集的语音数据,该数据集根据 Creative Commons BY 4.0 许可。
IPython.display.Audio(SPEECH_FILE)
要加载数据,我们使用 torchaudio.load()
。
如果采样率与流水线预期的不同,那么我们可以使用 torchaudio.functional.resample()
进行重采样。
注意
torchaudio.functional.resample()
也可以在 CUDA 张量上工作。在同一组采样率上多次执行重采样时,使用
torchaudio.transforms.Resample
可能会提高性能。
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)
if sample_rate != bundle.sample_rate:
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
/pytorch/audio/src/torchaudio/_backend/utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.ac.cn/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.
warnings.warn(
/pytorch/audio/src/torchaudio/_backend/ffmpeg.py:88: UserWarning: torio.io._streaming_media_decoder.StreamingMediaDecoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. The decoding and encoding capabilities of PyTorch for both audio and video are being consolidated into TorchCodec. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
s = torchaudio.io.StreamReader(src, format, None, buffer_size)
提取声学特征¶
下一步是从音频中提取声学特征。
注意
针对 ASR 任务微调的 Wav2Vec2 模型可以通过一步执行特征提取和分类,但为了本教程的目的,我们还展示了如何在此处执行特征提取。
with torch.inference_mode():
features, _ = model.extract_features(waveform)
返回的特征是张量列表。每个张量都是变压器层的输出。
fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
ax[i].imshow(feats[0].cpu(), interpolation="nearest")
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
fig.tight_layout()

特征分类¶
提取声学特征后,下一步是将其分类为一组类别。
Wav2Vec2 模型提供了一步执行特征提取和分类的方法。
with torch.inference_mode():
emission, _ = model(waveform)
输出是 logits 形式。它不是概率形式。
让我们可视化一下。
plt.imshow(emission[0].cpu().T, interpolation="nearest")
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.tight_layout()
print("Class labels:", bundle.get_labels())

Class labels: ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
我们可以看到,在时间线上,对某些标签有强烈的指示。
生成转录¶
从标签概率序列中,我们现在想生成转录。生成假设的过程通常称为“解码”。
解码比简单的分类更复杂,因为在某个时间步的解码可能会受到周围观测值的影响。
例如,以“night”和“knight”这样的词为例。即使它们的先验概率分布不同(在典型的对话中,“night”出现的频率远高于“knight”),为了准确地生成包含“knight”的转录,例如“a knight with a sword”,解码过程必须推迟最终决定,直到看到足够的上下文。
提出了许多解码技术,它们需要外部资源,例如词典和语言模型。
在本教程中,为了简单起见,我们将执行贪婪解码,它不依赖于这些外部组件,并且简单地在每个时间步选择最佳假设。因此,不使用上下文信息,并且只能生成一个转录。
我们首先定义贪婪解码算法。
class GreedyCTCDecoder(torch.nn.Module):
def __init__(self, labels, blank=0):
super().__init__()
self.labels = labels
self.blank = blank
def forward(self, emission: torch.Tensor) -> str:
"""Given a sequence emission over labels, get the best path string
Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
Returns:
str: The resulting transcript
"""
indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank]
return "".join([self.labels[i] for i in indices])
现在创建解码器对象并解码转录。
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
让我们检查结果并再次听音频。
print(transcript)
IPython.display.Audio(SPEECH_FILE)
I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|
ASR 模型使用一种称为连接主义时间分类(CTC)的损失函数进行微调。CTC 损失的详细信息此处有解释。在 CTC 中,空白标记 (ϵ) 是一个特殊标记,表示重复前面的符号。在解码时,这些符号会被简单地忽略。
结论¶
在本教程中,我们了解了如何使用 Wav2Vec2ASRBundle
执行声学特征提取和语音识别。构建模型并获取排放只需要两行代码。
model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
emission = model(waveforms, ...)
脚本总运行时间: (0 分 4.907 秒)