
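Music Source Separation with Hybrid Demucs

Author: Sean Kim

This tutorial shows how to use the Hybrid Demucs model to perform music separation.

1. Overview

Performing music separation consists of the following steps:

  1. Build the Hybrid Demucs pipeline.

  2. Format the waveform into chunks of the expected size and loop through the chunks (with overlap), feeding each one into the pipeline.

  3. Collect the output chunks and combine them according to how they overlap.

The Hybrid Demucs [Défossez, 2021] model is an evolution of the Demucs model, a waveform-based model that separates music into its respective sources, such as vocals, bass, and drums. Hybrid Demucs additionally uses spectrograms to learn in the frequency domain, alongside temporal convolutions.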

2. Preparation

First, we install the necessary dependencies. The first requirements are torchaudio and torch.

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

import matplotlib.pyplot as plt
2.8.0+cu126
2.8.0

In addition to torchaudio, mir_eval is required to compute the signal-to-distortion ratio (SDR). To install mir_eval, run pip3 install mir_eval.

from IPython.display import Audio
from mir_eval import separation
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.utils import download_asset

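3. Construct the pipeline

Pre-trained model weights and related pipeline components are bundled as torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS. This is a torchaudio.models.HDemucs model trained on MUSDB18-HQ and additional internal training data. This specific model is suited to higher sample rates, around 44.1 kHz, and its implementation uses an nfft value of 4096 and a depth of 6.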

bundle = HDEMUCS_HIGH_MUSDB_PLUS

model = bundle.get_model()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)

sample_rate = bundle.sample_rate

print(f"Sample rate: {sample_rate}")
Sample rate: 44100

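4. Configure the application function

Because HDemucs is a large and memory-hungry model, it is very difficult to have enough memory to apply the model to an entire song at once. To work around this limitation, obtain the separated sources of a full song by splitting the song into smaller segments, running the model on each segment, and then reassembling the results.

When doing this, it is important to ensure some overlap between consecutive chunks in order to accommodate artifacts at the edges. Due to the nature of the model, the edges sometimes contain inaccurate or undesired sounds.

We provide a sample implementation of chunking and arrangement below. This implementation overlaps adjacent chunks by `overlap * sample_rate` frames (0.1 seconds with the default `overlap=0.1`) and applies a linear fade-in and fade-out over each overlap. Using the faded overlaps, we add the segments together so that the volume stays constant throughout. This accommodates the artifacts by giving less weight to the edges of the model outputs.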
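Why faded overlaps keep the volume constant can be checked with a few lines of plain Python. The sketch below is illustrative only: it assumes the linear ramps are evenly spaced between 0 and 1, and `linear_ramp` and `n_overlap` are hypothetical names that are not part of the tutorial code.

```python
def linear_ramp(n):
    """Return n evenly spaced weights from 0.0 to 1.0 (requires n >= 2)."""
    return [i / (n - 1) for i in range(n)]


n_overlap = 5  # hypothetical overlap length in frames, kept tiny for readability

fade_in = linear_ramp(n_overlap)        # weights for the start of the next chunk
fade_out = [1.0 - w for w in fade_in]   # weights for the end of the previous chunk

# In the overlap region both chunks contribute, so their weights are summed.
combined = [a + b for a, b in zip(fade_out, fade_in)]
print(combined)
```

Every entry of `combined` is 1.0, so a signal that both chunks reproduce identically comes out of the overlap at its original level, while each chunk's outermost frames, where artifacts are most likely, receive the smallest weight.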

Figure: https://download.pytorch.org/torchaudio/tutorial-assets/HDemucs_Drawing.jpg

from torchaudio.transforms import Fade


def separate_sources(
    model,
    mix,
    segment=10.0,
    overlap=0.1,
    device=None,
):
    """
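    Apply the model to a given mixture segment by segment. Each segment is faded
    at its edges and added into the output so that overlapping regions cross-fade.

    Args:
        model (torch.nn.Module): separation model to apply
        mix (torch.Tensor): mixture of shape (batch, channels, length)
        segment (float): segment length in seconds
        overlap (float): overlap factor; the cross-fade spans
            `overlap * sample_rate` frames
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.

    Note: `sample_rate` is read from the enclosing module scope.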
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    chunk_len = int(sample_rate * segment * (1 + overlap))
    start = 0
    end = chunk_len
    overlap_frames = overlap * sample_rate
    fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")

    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    while start < length - overlap_frames:
        chunk = mix[:, :, start:end]
        with torch.no_grad():
            out = model.forward(chunk)
        out = fade(out)
        final[:, :, :, start:end] += out
        if start == 0:
            fade.fade_in_len = int(overlap_frames)
            start += int(chunk_len - overlap_frames)
        else:
            start += chunk_len
        end += chunk_len
        if end >= length:
            fade.fade_out_len = 0
    return final


def plot_spectrogram(stft, title="Spectrogram"):
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    _, axis = plt.subplots(1, 1)
    axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
    axis.set_title(title)
    plt.tight_layout()
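The index arithmetic in `separate_sources` can be replayed without a model to see which frames each chunk covers. This is a minimal sketch under stated assumptions: `chunk_bounds` is a hypothetical helper that copies only the start/end bookkeeping of the loop, and the toy values (100 Hz, 1-second segments, `overlap=0.5`) are chosen purely so the numbers are easy to follow.

```python
def chunk_bounds(length, sample_rate, segment, overlap):
    """Replay the start/end bookkeeping of separate_sources without a model."""
    chunk_len = int(sample_rate * segment * (1 + overlap))
    overlap_frames = overlap * sample_rate
    start, end = 0, chunk_len
    bounds = []
    while start < length - overlap_frames:
        # Slicing past the end of a tensor is clipped, so mirror that here.
        bounds.append((start, min(end, length)))
        if start == 0:
            start += int(chunk_len - overlap_frames)
        else:
            start += chunk_len
        end += chunk_len
    return bounds


# Hypothetical toy values, not the tutorial's settings.
bounds = chunk_bounds(length=400, sample_rate=100, segment=1.0, overlap=0.5)
print(bounds)  # [(0, 150), (100, 300), (250, 400)]
```

Consecutive chunks share 50 frames here, which equals `overlap * sample_rate`; that shared region is where the fade-out of one chunk and the fade-in of the next are summed.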

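5. Run the model

Finally, we run the model and keep the separated sources in a dictionary keyed by source name.

As a test song, we will use A Classic Education by NightOwl from MedleyDB (Creative Commons BY-NC-SA 4.0). The song is also part of the train sources of the MUSDB18-HQ dataset.

To test with a different song, change the variable names and URL below, and adjust the parameters to exercise the song separator in different ways.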

# We download the audio file from our storage. Feel free to download another file and use audio from a specific path
SAMPLE_SONG = download_asset("tutorial-assets/hdemucs_mix.wav")
waveform, sample_rate = torchaudio.load(SAMPLE_SONG)  # replace SAMPLE_SONG with desired path for different song
waveform = waveform.to(device)
mixture = waveform

# parameters
segment: int = 10
overlap = 0.1

print("Separating track")

ref = waveform.mean(0)
waveform = (waveform - ref.mean()) / ref.std()  # normalization

sources = separate_sources(
    model,
    waveform[None],
    device=device,
    segment=segment,
    overlap=overlap,
)[0]
sources = sources * ref.std() + ref.mean()

sources_list = model.sources
sources = list(sources)

audios = dict(zip(sources_list, sources))
Separating track
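The cell above normalizes the waveform before separation and undoes the normalization afterwards, using the mean and standard deviation of the mono reference. The arithmetic of that round trip can be sketched in plain Python. The helper names and the toy samples are hypothetical, and `statistics.stdev` is used on the assumption that the sample (Bessel-corrected) standard deviation is wanted, which is also the default of `torch.Tensor.std`.

```python
import statistics


def normalize(samples):
    """Shift and scale samples to zero mean and unit sample standard deviation."""
    mean = statistics.mean(samples)
    std = statistics.stdev(samples)
    return [(s - mean) / std for s in samples], mean, std


def denormalize(samples, mean, std):
    """Undo normalize() using the statistics saved from the input."""
    return [s * std + mean for s in samples]


toy_mixture = [0.5, -1.0, 2.0, 0.0, 1.5]  # hypothetical mono samples
normalized, mean, std = normalize(toy_mixture)

# A real run would pass `normalized` through the separation model here.
restored = denormalize(normalized, mean, std)
```

Saving `mean` and `std` from the input is what lets the outputs be mapped back to the original scale, so the separated sources are directly comparable to the mixture.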

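5.1 Separate the track

The default set of pre-trained weights that has been loaded has 4 separated sources: drums, bass, other, and vocals, in that order. They are stored in the dictionary "audios" and can be accessed there. For each of the four sources there is a separate cell that creates the audio, plots the spectrogram, and calculates the SDR score. SDR is the signal-to-distortion ratio, essentially a measure of the "quality" of an audio track.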
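To make the idea of SDR concrete, here is a minimal sketch of the simplest form of the ratio: the energy of the reference signal divided by the energy of the error, in decibels. It is not what `mir_eval.separation.bss_eval_sources` computes (that function uses a more elaborate decomposition of the estimate), so the numbers will not match the scores reported in this tutorial. `simple_sdr` and the toy signals are hypothetical.

```python
import math


def simple_sdr(reference, estimate):
    """Return 10 * log10(signal energy / error energy) in decibels."""
    signal_energy = sum(r * r for r in reference)
    error_energy = sum((r - e) ** 2 for r, e in zip(reference, estimate))
    return 10.0 * math.log10(signal_energy / error_energy)


reference = [1.0, 0.0, -1.0, 0.0]  # hypothetical ground-truth samples
estimate = [1.1, 0.1, -0.9, 0.1]   # hypothetical estimate, each sample off by 0.1

print(f"{simple_sdr(reference, estimate):.2f} dB")  # 16.99 dB
```

A higher value means the estimate is closer to the reference. Halving the error amplitude would raise this simplified score by about 6 dB.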

N_FFT = 4096
N_HOP = 4
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)
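The size of the spectrogram these settings produce can be estimated by hand. The sketch below assumes a one-sided, centered STFT, which to my understanding is the default behavior of `torchaudio.transforms.Spectrogram`; `spectrogram_shape` is a hypothetical helper, and the 5-second clip length matches the segments used later in this tutorial.

```python
def spectrogram_shape(num_samples, n_fft, hop_length):
    """Return (frequency bins, time frames) for a one-sided, centered STFT."""
    freq_bins = n_fft // 2 + 1
    time_frames = 1 + num_samples // hop_length
    return freq_bins, time_frames


clip_samples = 5 * 44100  # a 5-second clip at 44.1 kHz
shape = spectrogram_shape(clip_samples, n_fft=4096, hop_length=4)
print(shape)  # (2049, 55126)
```

With a hop of only 4 samples the frames overlap heavily, which gives a very dense time axis: roughly 113 million complex values per channel for a 5-second clip. A larger `hop_length` would reduce memory use at the cost of time resolution in the plot.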

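5.2 Audio segmenting and processing

The steps below extract a 5-second segment from each track, which is then fed into the spectrogram and used to calculate the corresponding SDR score.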

def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor, source: str):
    print(
        "SDR score is:",
        separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
    )
    plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
    return Audio(predicted_source, rate=sample_rate)


segment_start = 150
segment_end = 155

frame_start = segment_start * sample_rate
frame_end = segment_end * sample_rate

drums_original = download_asset("tutorial-assets/hdemucs_drums_segment.wav")
bass_original = download_asset("tutorial-assets/hdemucs_bass_segment.wav")
vocals_original = download_asset("tutorial-assets/hdemucs_vocals_segment.wav")
other_original = download_asset("tutorial-assets/hdemucs_other_segment.wav")

drums_spec = audios["drums"][:, frame_start:frame_end].cpu()
drums, sample_rate = torchaudio.load(drums_original)

bass_spec = audios["bass"][:, frame_start:frame_end].cpu()
bass, sample_rate = torchaudio.load(bass_original)

vocals_spec = audios["vocals"][:, frame_start:frame_end].cpu()
vocals, sample_rate = torchaudio.load(vocals_original)

other_spec = audios["other"][:, frame_start:frame_end].cpu()
other, sample_rate = torchaudio.load(other_original)

mix_spec = mixture[:, frame_start:frame_end].cpu()

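5.3 Spectrograms and audio

In the next 5 cells, you can see the spectrograms together with the corresponding audio. The spectrogram gives a clear visual representation of each clip.

The mixture clip comes from the original track, and the remaining clips are the model output.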

# Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate)
Spectrogram - Mixture


Drums SDR, spectrogram, and audio

# Drums Clip
output_results(drums, drums_spec, "drums")
Spectrogram - drums
SDR score is: 4.964587537534197


Bass SDR, spectrogram, and audio

# Bass Clip
output_results(bass, bass_spec, "bass")
Spectrogram - bass
SDR score is: 18.905752050002093


Vocals SDR, spectrogram, and audio

# Vocals Audio
output_results(vocals, vocals_spec, "vocals")
Spectrogram - vocals
SDR score is: 8.792465938875653


Other SDR, spectrogram, and audio

# Other Clip
output_results(other, other_spec, "other")
Spectrogram - other
SDR score is: 8.867021191079175


# Optionally, the full audios can be heard by running the next 5
# cells. They will take a bit longer to load, so to run them simply
# uncomment the ``Audio`` cells for the respective track to produce the
# audio for the full song.
#

# Full Audio
# Audio(mixture, rate=sample_rate)

# Drums Audio
# Audio(audios["drums"], rate=sample_rate)

# Bass Audio
# Audio(audios["bass"], rate=sample_rate)

# Vocals Audio
# Audio(audios["vocals"], rate=sample_rate)

# Other Audio
# Audio(audios["other"], rate=sample_rate)

Total running time of the script: (0 minutes 29.065 seconds)
