
CTC Forced Alignment API Tutorial

Authors: Xiaohui Zhang, Moto Hira

Warning

Starting with the 2.8 release, we are refactoring TorchAudio to transition it into a maintenance phase. As a result:

  • The APIs described in this tutorial are deprecated in 2.8 and will be removed in 2.9.

  • The decoding and encoding capabilities of PyTorch for both audio and video are being consolidated into TorchCodec.

Please see https://github.com/pytorch/audio/issues/3902 for more information.

Forced alignment is the process of aligning a transcript to speech. This tutorial shows how to align a transcript to speech using torchaudio.functional.forced_align(), which was developed along with the work Scaling Speech Technology to 1,000+ Languages.

forced_align() has custom CPU and CUDA implementations that are more performant than a vanilla Python implementation, and are more accurate. It can also handle missing portions of a transcript with the special <star> token.

There is also a high-level API, torchaudio.pipelines.Wav2Vec2FABundle, which wraps the pre-/post-processing explained in this tutorial and makes it easy to run forced alignment. The tutorial on forced alignment for multilingual data uses this API to illustrate how to align non-English transcripts.
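As a quick taste of that high-level API, here is a minimal sketch following the pattern used in the multilingual forced-alignment tutorial. It assumes waveform, device, and a word list transcript_words (a hypothetical name) are prepared as in the rest of this tutorial; exact signatures may differ across torchaudio versions.

bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model()
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

with torch.inference_mode():
    emission, _ = model(waveform.to(device))
    # tokenizer() maps each word to token IDs; aligner() returns
    # one list of TokenSpan objects per word in the transcript.
    token_spans = aligner(emission[0], tokenizer(transcript_words))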

Preparation

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)
2.8.0+cu126
2.8.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
import IPython
import matplotlib.pyplot as plt

import torchaudio.functional as F

First, we prepare the speech data and the transcript we will use.

SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
waveform, _ = torchaudio.load(SPEECH_FILE)
TRANSCRIPT = "i had that curiosity beside me at this moment".split()
/pytorch/audio/examples/tutorials/ctc_forced_alignment_api_tutorial.py:65: UserWarning: torchaudio.utils.download.download_asset has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
  SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
/pytorch/audio/src/torchaudio/_backend/utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.ac.cn/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.
  warnings.warn(
/pytorch/audio/src/torchaudio/_backend/ffmpeg.py:88: UserWarning: torio.io._streaming_media_decoder.StreamingMediaDecoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. The decoding and encoding capabilities of PyTorch for both audio and video are being consolidated into TorchCodec. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)

Generating emissions

forced_align() takes emissions and a token sequence, and outputs timestamps of the tokens along with their scores.

An emission represents the frame-wise probability distribution over tokens, and it can be obtained by passing a waveform to an acoustic model.

Tokens are the numerical expression of the transcript. There are many ways to tokenize a transcript, but here we simply map each letter to an integer, which is how the labels were constructed when the acoustic model we are going to use was trained.

We will use a pretrained Wav2Vec2 model, torchaudio.pipelines.MMS_FA, to obtain the emission and tokenize the transcript.

bundle = torchaudio.pipelines.MMS_FA

model = bundle.get_model(with_star=False).to(device)
with torch.inference_mode():
    emission, _ = model(waveform.to(device))
Downloading: "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" to /root/.cache/torch/hub/checkpoints/model.pt

100%|##########| 1.18G/1.18G [00:04<00:00, 274MB/s]
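The emission has shape (batch, num_frames, num_labels); the frame axis is the time axis we align against. As a quick sanity check (the expected shape below is inferred from the 169 aligned frames and 28 dictionary entries shown later in this tutorial):

print(emission.shape)  # torch.Size([1, 169, 28])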
def plot_emission(emission):
    fig, ax = plt.subplots()
    ax.imshow(emission.cpu().T)
    ax.set_title("Frame-wise class probabilities")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.tight_layout()


plot_emission(emission[0])
[Figure: Frame-wise class probabilities]

Tokenize the transcript

We create a dictionary that maps each label to a token.

LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)
for k, v in DICTIONARY.items():
    print(f"{k}: {v}")
-: 0
a: 1
i: 2
e: 3
n: 4
o: 5
u: 6
t: 7
s: 8
r: 9
m: 10
k: 11
l: 12
d: 13
g: 14
h: 15
y: 16
b: 17
p: 18
w: 19
c: 20
v: 21
j: 22
z: 23
f: 24
': 25
q: 26
x: 27

Converting the transcript to tokens is as simple as

tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]

for t in tokenized_transcript:
    print(t, end=" ")
print()
2 15 1 13 7 15 1 7 20 6 9 2 5 8 2 7 16 17 3 8 2 13 3 10 3 1 7 7 15 2 8 10 5 10 3 4 7
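As a round-trip sanity check, the token IDs can be decoded back to characters using the LABELS tuple obtained above:

print("".join(LABELS[t] for t in tokenized_transcript))
# ihadthatcuriositybesidemeatthismoment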

Computing alignments

Frame-level alignments

Now we call TorchAudio's forced alignment API to compute the frame-level alignment. For details on the function signature, please refer to forced_align().

def align(emission, tokens):
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
    alignments, scores = F.forced_align(emission, targets, blank=0)

    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
    scores = scores.exp()  # convert back to probability
    return alignments, scores


aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
/pytorch/audio/examples/tutorials/ctc_forced_alignment_api_tutorial.py:146: UserWarning: torchaudio.functional._alignment.forced_align has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
  alignments, scores = F.forced_align(emission, targets, blank=0)

Now let's look at the output.

for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
    print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
  0:     0 [-], 1.00
  1:     0 [-], 1.00
  2:     0 [-], 1.00
  3:     0 [-], 1.00
  4:     0 [-], 1.00
  5:     0 [-], 1.00
  6:     0 [-], 1.00
  7:     0 [-], 1.00
  8:     0 [-], 1.00
  9:     0 [-], 1.00
 10:     0 [-], 1.00
 11:     0 [-], 1.00
 12:     0 [-], 1.00
 13:     0 [-], 1.00
 14:     0 [-], 1.00
 15:     0 [-], 1.00
 16:     0 [-], 1.00
 17:     0 [-], 1.00
 18:     0 [-], 1.00
 19:     0 [-], 1.00
 20:     0 [-], 1.00
 21:     0 [-], 1.00
 22:     0 [-], 1.00
 23:     0 [-], 1.00
 24:     0 [-], 1.00
 25:     0 [-], 1.00
 26:     0 [-], 1.00
 27:     0 [-], 1.00
 28:     0 [-], 1.00
 29:     0 [-], 1.00
 30:     0 [-], 1.00
 31:     0 [-], 1.00
 32:     2 [i], 1.00
 33:     0 [-], 1.00
 34:     0 [-], 1.00
 35:    15 [h], 1.00
 36:    15 [h], 0.93
 37:     1 [a], 1.00
 38:     0 [-], 0.96
 39:     0 [-], 1.00
 40:     0 [-], 1.00
 41:    13 [d], 1.00
 42:     0 [-], 1.00
 43:     0 [-], 0.97
 44:     7 [t], 1.00
 45:    15 [h], 1.00
 46:     0 [-], 0.98
 47:     1 [a], 1.00
 48:     0 [-], 1.00
 49:     0 [-], 1.00
 50:     7 [t], 1.00
 51:     0 [-], 1.00
 52:     0 [-], 1.00
 53:     0 [-], 1.00
 54:    20 [c], 1.00
 55:     0 [-], 1.00
 56:     0 [-], 1.00
 57:     0 [-], 1.00
 58:     6 [u], 1.00
 59:     6 [u], 0.96
 60:     0 [-], 1.00
 61:     0 [-], 1.00
 62:     0 [-], 0.53
 63:     9 [r], 1.00
 64:     0 [-], 1.00
 65:     2 [i], 1.00
 66:     0 [-], 1.00
 67:     0 [-], 1.00
 68:     0 [-], 1.00
 69:     0 [-], 1.00
 70:     0 [-], 1.00
 71:     0 [-], 0.96
 72:     5 [o], 1.00
 73:     0 [-], 1.00
 74:     0 [-], 1.00
 75:     0 [-], 1.00
 76:     0 [-], 1.00
 77:     0 [-], 1.00
 78:     0 [-], 1.00
 79:     8 [s], 1.00
 80:     0 [-], 1.00
 81:     0 [-], 1.00
 82:     0 [-], 0.99
 83:     2 [i], 1.00
 84:     0 [-], 1.00
 85:     7 [t], 1.00
 86:     0 [-], 1.00
 87:     0 [-], 1.00
 88:    16 [y], 1.00
 89:     0 [-], 1.00
 90:     0 [-], 1.00
 91:     0 [-], 1.00
 92:     0 [-], 1.00
 93:    17 [b], 1.00
 94:     0 [-], 1.00
 95:     3 [e], 1.00
 96:     0 [-], 1.00
 97:     0 [-], 1.00
 98:     0 [-], 1.00
 99:     0 [-], 1.00
100:     0 [-], 1.00
101:     8 [s], 1.00
102:     0 [-], 1.00
103:     0 [-], 1.00
104:     0 [-], 1.00
105:     0 [-], 1.00
106:     0 [-], 1.00
107:     0 [-], 1.00
108:     0 [-], 1.00
109:     0 [-], 0.64
110:     2 [i], 1.00
111:     0 [-], 1.00
112:     0 [-], 1.00
113:    13 [d], 1.00
114:     3 [e], 0.85
115:     0 [-], 1.00
116:    10 [m], 1.00
117:     0 [-], 1.00
118:     0 [-], 1.00
119:     3 [e], 1.00
120:     0 [-], 1.00
121:     0 [-], 1.00
122:     0 [-], 1.00
123:     0 [-], 1.00
124:     1 [a], 1.00
125:     0 [-], 1.00
126:     0 [-], 1.00
127:     7 [t], 1.00
128:     0 [-], 1.00
129:     7 [t], 1.00
130:    15 [h], 1.00
131:     0 [-], 0.79
132:     2 [i], 1.00
133:     0 [-], 1.00
134:     0 [-], 1.00
135:     0 [-], 1.00
136:     8 [s], 1.00
137:     0 [-], 1.00
138:     0 [-], 1.00
139:     0 [-], 1.00
140:     0 [-], 1.00
141:    10 [m], 1.00
142:     0 [-], 1.00
143:     0 [-], 1.00
144:     5 [o], 1.00
145:     0 [-], 1.00
146:     0 [-], 1.00
147:     0 [-], 1.00
148:    10 [m], 1.00
149:     0 [-], 1.00
150:     0 [-], 1.00
151:     3 [e], 1.00
152:     0 [-], 1.00
153:     4 [n], 1.00
154:     0 [-], 1.00
155:     7 [t], 1.00
156:     0 [-], 1.00
157:     0 [-], 1.00
158:     0 [-], 1.00
159:     0 [-], 1.00
160:     0 [-], 1.00
161:     0 [-], 1.00
162:     0 [-], 1.00
163:     0 [-], 1.00
164:     0 [-], 1.00
165:     0 [-], 1.00
166:     0 [-], 1.00
167:     0 [-], 1.00
168:     0 [-], 1.00

Note

The alignment is expressed in the frame coordinates of the emission, which is different from the original waveform.
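To relate an emission frame index back to time in the original audio, scale by the ratio of waveform samples to emission frames, as the preview_word helper later in this tutorial does. A minimal sketch:

# One emission frame covers waveform.size(1) / emission.size(1) samples.
ratio = waveform.size(1) / emission.size(1)

def frame_to_seconds(frame_index, sample_rate=bundle.sample_rate):
    # Convert an emission-frame index to seconds in the original waveform.
    return frame_index * ratio / sample_rate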

The alignment output contains blank tokens and repetitions of the same token. The following is an interpretation of the non-blank tokens.

31:     0 [-], 1.00
32:     2 [i], 1.00  "i" starts and ends
33:     0 [-], 1.00
34:     0 [-], 1.00
35:    15 [h], 1.00  "h" starts
36:    15 [h], 0.93  "h" ends
37:     1 [a], 1.00  "a" starts and ends
38:     0 [-], 0.96
39:     0 [-], 1.00
40:     0 [-], 1.00
41:    13 [d], 1.00  "d" starts and ends
42:     0 [-], 1.00

Note

When the same token appears after a blank token, it is not treated as a repetition, but as a new occurrence.

a a a b -> a b
a - - b -> a b
a a - b -> a b
a - a b -> a a b
  ^^^       ^^^
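These rules are the standard CTC collapse: merge consecutive repeats first, then drop blanks. A minimal pure-Python sketch (ctc_collapse is a hypothetical helper, not part of torchaudio) reproducing the four examples above:

def ctc_collapse(tokens, blank="-"):
    out, prev = [], None
    for t in tokens:
        # A token is emitted only when it differs from the previous
        # frame's token (a repeat) and is not the blank symbol.
        if t != prev and t != blank:
            out.append(t)
        prev = t
    return out

print(ctc_collapse("aaab"))  # ['a', 'b']
print(ctc_collapse("a--b"))  # ['a', 'b']
print(ctc_collapse("aa-b"))  # ['a', 'b']
print(ctc_collapse("a-ab"))  # ['a', 'a', 'b']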

Token-level alignments

The next step is to resolve the repetitions, so that each alignment does not depend on previous alignments. torchaudio.functional.merge_tokens() computes TokenSpan objects, which represent which token from the transcript is present at which time span.

token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

print("Token\tTime\tScore")
for s in token_spans:
    print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
Token   Time    Score
i       [ 32,  33)      1.00
h       [ 35,  37)      0.96
a       [ 37,  38)      1.00
d       [ 41,  42)      1.00
t       [ 44,  45)      1.00
h       [ 45,  46)      1.00
a       [ 47,  48)      1.00
t       [ 50,  51)      1.00
c       [ 54,  55)      1.00
u       [ 58,  60)      0.98
r       [ 63,  64)      1.00
i       [ 65,  66)      1.00
o       [ 72,  73)      1.00
s       [ 79,  80)      1.00
i       [ 83,  84)      1.00
t       [ 85,  86)      1.00
y       [ 88,  89)      1.00
b       [ 93,  94)      1.00
e       [ 95,  96)      1.00
s       [101, 102)      1.00
i       [110, 111)      1.00
d       [113, 114)      1.00
e       [114, 115)      0.85
m       [116, 117)      1.00
e       [119, 120)      1.00
a       [124, 125)      1.00
t       [127, 128)      1.00
t       [129, 130)      1.00
h       [130, 131)      1.00
i       [132, 133)      1.00
s       [136, 137)      1.00
m       [141, 142)      1.00
o       [144, 145)      1.00
m       [148, 149)      1.00
e       [151, 152)      1.00
n       [153, 154)      1.00
t       [155, 156)      1.00

Word-level alignments

Now let's group the token-level alignments into word-level alignments.

def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret


word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])
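Each entry of word_spans is the list of TokenSpan objects for one word. For instance, the spans for the second word, "had", can be inspected like this (the expected values below are taken from the token-level table above):

for span in word_spans[1]:
    print(LABELS[span.token], span.start, span.end, f"{span.score:.2f}")
# h 35 37 0.96
# a 37 38 1.00
# d 41 42 1.00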

Audio previews

# Compute average score weighted by the span length
def _score(spans):
    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)


def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / num_frames
    x0 = int(ratio * spans[0].start)
    x1 = int(ratio * spans[-1].end)
    print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)


num_frames = emission.size(1)
# Generate the audio for each segment
print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
['i', 'had', 'that', 'curiosity', 'beside', 'me', 'at', 'this', 'moment']


preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])
i (1.00): 0.644 - 0.664 sec


preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])
had (0.98): 0.704 - 0.845 sec


preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])
that (1.00): 0.885 - 1.026 sec


preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])
curiosity (1.00): 1.086 - 1.790 sec


preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])
beside (0.97): 1.871 - 2.314 sec


preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])
me (1.00): 2.334 - 2.414 sec


preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])
at (1.00): 2.495 - 2.575 sec


preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])
this (1.00): 2.595 - 2.756 sec


preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])
moment (1.00): 2.837 - 3.138 sec


Visualization

Now let's look at the alignment results and segment the original speech into words.

def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / emission.size(1) / sample_rate

    fig, axes = plt.subplots(2, 1)
    axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
    axes[0].set_title("Emission")
    axes[0].set_xticks([])

    axes[1].specgram(waveform[0], Fs=sample_rate)
    for t_spans, chars in zip(token_spans, transcript):
        t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
        axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
        axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

        for span, char in zip(t_spans, chars):
            t0 = span.start * ratio
            axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)

    axes[1].set_xlabel("time [second]")
    axes[1].set_xlim([0, None])
    fig.tight_layout()
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
[Figure: Emission]

Inconsistent treatment of blank tokens

When splitting the token-level alignments into words, you will notice that some blank tokens are treated differently, which makes the interpretation of the result somewhat ambiguous.

This is easy to see when we plot the scores. The following figure shows the word regions and the non-word regions, along with the frame-level scores of the non-blank tokens.

def plot_scores(word_spans, scores):
    fig, ax = plt.subplots()
    span_xs, span_hs = [], []
    ax.axvspan(word_spans[0][0].start - 0.05, word_spans[-1][-1].end + 0.05, facecolor="paleturquoise", edgecolor="none", zorder=-1)
    for t_span in word_spans:
        for span in t_span:
            for t in range(span.start, span.end):
                span_xs.append(t + 0.5)
                span_hs.append(scores[t].item())
            ax.annotate(LABELS[span.token], (span.start, -0.07))
        ax.axvspan(t_span[0].start - 0.05, t_span[-1].end + 0.05, facecolor="mistyrose", edgecolor="none", zorder=-1)
    ax.bar(span_xs, span_hs, color="lightsalmon", edgecolor="coral")
    ax.set_title("Frame-level scores and word segments")
    ax.set_ylim(-0.1, None)
    ax.grid(True, axis="y")
    ax.axhline(0, color="black")
    fig.tight_layout()


plot_scores(word_spans, alignment_scores)
[Figure: Frame-level scores and word segments]

In this plot, the blank tokens are the highlighted areas without vertical bars. You can see that some blank tokens are interpreted as part of a word (highlighted in red), while others (highlighted in blue) are not.

One reason for this is that the model was trained without labels for word boundaries. Blank tokens are treated not only as repetitions but also as silence between words.

But then a question arises: should frames at or right after the end of a word be silence or a repetition?

In the example above, if you go back to the earlier plot of the spectrogram and word spans, you can see that after the "y" in "curiosity" there is still some activity in multiple frequency bands.

Would it be more accurate if that frame were included in the word?

Unfortunately, CTC does not provide a comprehensive solution to this. Models trained with CTC are known to exhibit a "peaky" response, that is, they tend to spike when a label appears, but the spike does not last for the duration of the label. (Note: pretrained Wav2Vec2 models tend to spike at the beginning of a label occurrence, but this is not always the case.)

[Zeyer et al., 2021] provides an in-depth analysis of the peaky behavior of CTC. We encourage those interested to refer to it. The following quote from the paper describes exactly the issue we are facing here.

Peaky behavior can be problematic in certain cases, e.g. when an application requires to not use the blank label, e.g. to get meaningful time accurate alignments of phonemes to a transcription.

Advanced: Handling transcripts with the <star> token

Now let's look at how, when the transcript is partially missing, we can improve alignment quality using the <star> token, which is capable of modeling any token.

Here we use the same English example as above, but we remove the beginning text "i had that curiosity beside me at" from the transcript. Aligning the audio with such a transcript results in wrong alignments of the existing word "this". However, this issue can be mitigated by using the <star> token to model the missing text.

First, we extend the dictionary to include the <star> token.

DICTIONARY["*"] = len(DICTIONARY)

Next, we extend the emission tensor with an extra dimension corresponding to the <star> token.

star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
emission = torch.cat((emission, star_dim), 2)

assert len(DICTIONARY) == emission.shape[2]

plot_emission(emission[0])
[Figure: Frame-wise class probabilities]

The following function combines all the processes above and computes word segments from an emission in one go.

def compute_alignments(emission, transcript, dictionary):
    tokens = [dictionary[char] for word in transcript for char in word]
    alignment, scores = align(emission, tokens)
    token_spans = F.merge_tokens(alignment, scores)
    word_spans = unflatten(token_spans, [len(word) for word in transcript])
    return word_spans

Full transcript

word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
[Figure: Emission]
/pytorch/audio/examples/tutorials/ctc_forced_alignment_api_tutorial.py:146: UserWarning: torchaudio.functional._alignment.forced_align has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
  alignments, scores = F.forced_align(emission, targets, blank=0)

Partial transcript with the <star> token

Now we replace the first part of the transcript with the <star> token.

transcript = "* this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
[Figure: Emission]
/pytorch/audio/examples/tutorials/ctc_forced_alignment_api_tutorial.py:146: UserWarning: torchaudio.functional._alignment.forced_align has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
  alignments, scores = F.forced_align(emission, targets, blank=0)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
* (1.00): 0.000 - 2.595 sec


preview_word(waveform, word_spans[1], num_frames, transcript[1])
this (1.00): 2.595 - 2.756 sec


preview_word(waveform, word_spans[2], num_frames, transcript[2])
moment (1.00): 2.837 - 3.138 sec


Partial transcript without the <star> token

As a comparison, the following aligns the partial transcript without using the <star> token. It demonstrates the effect of the <star> token on handling deletion errors.

transcript = "this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
[Figure: Emission]
/pytorch/audio/examples/tutorials/ctc_forced_alignment_api_tutorial.py:146: UserWarning: torchaudio.functional._alignment.forced_align has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
  alignments, scores = F.forced_align(emission, targets, blank=0)

Conclusion

In this tutorial, we looked at how to use torchaudio's forced alignment API to align and segment speech files, and demonstrated one advanced usage: how introducing the <star> token can improve alignment accuracy when transcription errors exist.

Acknowledgement

Thanks to Vineel Pratap and Zhaoheng Ni for developing and open-sourcing the forced aligner API.

Total running time of the script: (0 minutes 7.698 seconds)

Gallery generated by Sphinx-Gallery
