并行视频解码:多进程与多线程
在本教程中,我们将探讨并行解码大量视频帧的不同方法。我们将比较三种并行策略:
基于 FFmpeg 的并行:使用 FFmpeg 的内部线程功能
Joblib 多进程:将工作分配给多个进程
Joblib 多线程:在单个进程中使用多个线程
我们将使用 joblib 进行并行处理,因为它提供了非常方便的 API 来将工作分配到多个进程或线程。但这只是 Python 中并行处理工作的一种方式。你绝对可以使用其他线程或进程池管理器。
让我们先定义一些用于基准测试和数据处理的实用函数。我们还将下载一个视频,并将其重复多次以生成一个更长的版本,以此模拟需要高效处理长视频的场景。你可以跳过这部分,直接阅读「帧采样策略」一节。
from typing import List
import torch
import requests
import tempfile
from pathlib import Path
import subprocess
from time import perf_counter_ns
from datetime import timedelta
from joblib import Parallel, delayed, cpu_count
from torchcodec.decoders import VideoDecoder
def bench(f, *args, num_exp=3, warmup=1, **kwargs):
    """Benchmark a function by calling it repeatedly and timing each call.

    ``f`` is first called ``warmup`` times without timing. Then ``num_exp``
    timed calls are made, each measured with ``perf_counter_ns``.

    Args:
        f: Callable to benchmark.
        *args, **kwargs: Forwarded unchanged to every call of ``f``.
        num_exp: Number of timed calls; must be at least 1.
        warmup: Number of untimed calls made before timing starts.

    Returns:
        A tuple ``(times, result)``. ``times`` is a float tensor holding the
        ``num_exp`` per-call durations in nanoseconds. ``result`` is the
        return value of the last timed call.

    Raises:
        ValueError: If ``num_exp`` is less than 1. Without at least one timed
            call there is no result to return.
    """
    if num_exp < 1:
        raise ValueError(f"num_exp must be at least 1, got {num_exp}.")
    for _ in range(warmup):
        f(*args, **kwargs)
    times = []
    for _ in range(num_exp):
        start = perf_counter_ns()
        result = f(*args, **kwargs)
        end = perf_counter_ns()
        times.append(end - start)
    return torch.tensor(times).float(), result
def report_stats(times, unit="s"):
    """Print the median and standard deviation of ``times`` and return the median.

    Args:
        times: Tensor of durations in nanoseconds, as returned by ``bench``.
        unit: Display unit; one of ``"ns"``, ``"µs"``, ``"ms"`` or ``"s"``.
            Any other value raises ``KeyError``.

    Returns:
        The median duration expressed in ``unit``, as a Python float.
    """
    ns_to_unit = {"ns": 1, "µs": 1e-3, "ms": 1e-6, "s": 1e-9}
    scaled = times * ns_to_unit[unit]
    spread = scaled.std().item()
    median = scaled.median().item()
    print(f"median = {median:.2f}{unit} ± {spread:.2f}")
    return median
def split_indices(indices: List[int], num_chunks: int) -> List[List[int]]:
    """Split ``indices`` into ``num_chunks`` contiguous, near-equal chunks.

    Chunk sizes differ by at most one: the first
    ``len(indices) % num_chunks`` chunks each get one extra element.
    Concatenating the chunks in order reproduces ``indices`` exactly.
    When there are fewer indices than chunks, the trailing chunks are empty.

    Args:
        indices: Items to split (frame indices in this tutorial).
        num_chunks: Number of chunks to produce; must be at least 1.

    Returns:
        A list of exactly ``num_chunks`` slices of ``indices``.

    Raises:
        ValueError: If ``num_chunks`` is less than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be at least 1, got {num_chunks}.")
    base_size, remainder = divmod(len(indices), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # Spread the remainder over the first chunks instead of piling it onto
        # the last one, so parallel workers receive balanced workloads.
        end = start + base_size + (1 if i < remainder else 0)
        chunks.append(indices[start:end])
        start = end
    return chunks
def generate_long_video(temp_dir: str, num_repeats: int = 50):
    """Download a short sample clip and build a longer video by looping it.

    Args:
        temp_dir: Directory in which both video files are written.
        num_repeats: How many times the short clip appears in the long video.
            Must be at least 1. Defaults to 50.

    Returns:
        A tuple ``(short_video_path, long_video_path)`` of ``Path`` objects.

    Raises:
        ValueError: If ``num_repeats`` is less than 1.
        RuntimeError: If the download does not answer with HTTP 200.
        subprocess.CalledProcessError: If the ``ffmpeg`` command fails.
    """
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be at least 1, got {num_repeats}.")

    # Video source: https://www.pexels.com/video/dog-eating-854132/
    # License: CC0. Author: Coverr.
    url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
    short_video_path = Path(temp_dir) / "short_video.mp4"
    # iter_content() defaults to 1-byte chunks, so the file would be written
    # byte by byte. Stream it in 64 KiB pieces instead. The timeout keeps a
    # stalled server from hanging the script, and the `with` block releases
    # the connection.
    with requests.get(url, headers={"User-Agent": ""}, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download video. {response.status_code = }.")
        with open(short_video_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                f.write(chunk)

    # Create a longer video by repeating the short one `num_repeats` times.
    long_video_path = Path(temp_dir) / "long_video.mp4"
    ffmpeg_command = [
        "ffmpeg", "-y",
        # -stream_loop N plays the input N *extra* times, hence num_repeats - 1.
        "-stream_loop", str(num_repeats - 1),
        "-i", str(short_video_path),
        # Copy the streams as-is instead of re-encoding them.
        "-c", "copy",
        str(long_video_path),
    ]
    subprocess.run(ffmpeg_command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return short_video_path, long_video_path
def _format_mmss(duration):
    """Format a timedelta as '<minutes>m<seconds:02d>s', truncating fractions."""
    minutes, seconds = divmod(duration.total_seconds(), 60)
    return f"{int(minutes)}m{int(seconds):02d}s"


# Scratch directory holding the downloaded clip and the looped long video.
temp_dir = tempfile.mkdtemp()
short_video_path, long_video_path = generate_long_video(temp_dir)

# Open the long video (approximate seek mode, as in the decoders below)
# to read its metadata.
decoder = VideoDecoder(long_video_path, seek_mode="approximate")
metadata = decoder.metadata

short_metadata = VideoDecoder(short_video_path).metadata
short_duration = timedelta(seconds=short_metadata.duration_seconds)
long_duration = timedelta(seconds=metadata.duration_seconds)

print(f"Original video duration: {_format_mmss(short_duration)}")
print(f"Long video duration: {_format_mmss(long_duration)}")
print(f"Video resolution: {metadata.width}x{metadata.height}")
print(f"Average FPS: {metadata.average_fps:.1f}")
print(f"Total frames: {metadata.num_frames}")
Original video duration: 0m13s
Long video duration: 11m30s
Video resolution: 640x360
Average FPS: 25.0
Total frames: 17250
帧采样策略
在本教程中,我们将以每秒 2 帧的速率(即 TARGET_FPS = 2)从长视频中采样。这模拟了一个常见的场景:你只需要处理其中一部分帧,例如用于 LLM 推理。
# Sample TARGET_FPS frames per second of video, i.e. keep one frame out of
# every `step` frames. max(1, ...) keeps step valid when the video's FPS is
# lower than TARGET_FPS.
TARGET_FPS = 2
step = max(1, round(metadata.average_fps / TARGET_FPS))
all_indices = list(range(0, metadata.num_frames, step))
print(f"Sampling {TARGET_FPS} frames per second")
print(f"We'll decode one frame out of every {step}")
print(f"Total frames to decode: {len(all_indices)}")
Sampling 1 frame every 2 seconds
We'll skip every 12 frames
Total frames to decode: 1438
方法 1:顺序解码(基准)
让我们从顺序方法开始作为我们的基准。这会逐帧处理,没有任何并行化。
def decode_sequentially(indices: List[int], video_path=long_video_path):
    """Decode the frames at ``indices`` with a single, freshly created decoder.

    The decoder (approximate seek mode) is built inside the call. This lets
    the function also serve as the per-chunk worker for the thread- and
    process-based strategies.

    Args:
        indices: Frame indices to decode.
        video_path: Video to decode; defaults to the long generated video.

    Returns:
        The frame batch returned by ``VideoDecoder.get_frames_at``; callers
        read its ``.data`` tensor.
    """
    return VideoDecoder(video_path, seek_mode="approximate").get_frames_at(indices)
# Baseline: one decoder, no parallelism. Its median time is the reference
# for every speedup reported below.
times, result_sequential = bench(decode_sequentially, all_indices)
sequential_time = report_stats(times, "s")
median = 14.31s ± 0.02
方法 2:基于 FFmpeg 的并行
FFmpeg 具有内置的多线程功能,可以通过 num_ffmpeg_threads 参数进行控制。此方法利用 FFmpeg 内部的多个线程来加速解码操作。
def decode_with_ffmpeg_parallelism(
    indices: List[int],
    num_threads: int,
    video_path=long_video_path
):
    """Decode the frames at ``indices`` using FFmpeg's internal threading.

    One decoder is created, and FFmpeg is told to use ``num_threads``
    threads via ``num_ffmpeg_threads``.

    Args:
        indices: Frame indices to decode.
        num_threads: Number of FFmpeg decoding threads.
        video_path: Video to decode; defaults to the long generated video.

    Returns:
        The frame batch returned by ``VideoDecoder.get_frames_at``.
    """
    threaded_decoder = VideoDecoder(
        video_path,
        seek_mode="approximate",
        num_ffmpeg_threads=num_threads,
    )
    frames = threaded_decoder.get_frames_at(indices)
    return frames
# Use as many FFmpeg threads as joblib's cpu_count() reports.
NUM_CPUS = cpu_count()
times, result_ffmpeg = bench(
    decode_with_ffmpeg_parallelism,
    all_indices,
    num_threads=NUM_CPUS,
)
ffmpeg_time = report_stats(times, "s")
speedup = sequential_time / ffmpeg_time
print(f"Speedup compared to sequential: {speedup:.2f}x with {NUM_CPUS} FFmpeg threads.")
median = 7.09s ± 0.02
Speedup compared to sequential: 2.02x with 16 FFmpeg threads.
方法 3:Joblib 多进程
基于进程的并行将工作分配给多个 Python 进程。
def decode_with_multiprocessing(
    indices: List[int],
    num_processes: int,
    video_path=long_video_path
):
    """Decode the frames at ``indices`` across several worker processes.

    The indices are split into ``num_processes`` contiguous chunks. Each
    chunk is decoded by ``decode_sequentially`` in a separate process via
    joblib's loky backend.

    Args:
        indices: Frame indices to decode.
        num_processes: Number of worker processes (and chunks).
        video_path: Video to decode; defaults to the long generated video.

    Returns:
        A single tensor: the chunks' ``.data`` tensors concatenated along
        dim 0, in the original index order.
    """
    chunks = split_indices(indices, num_chunks=num_processes)
    # loky is a multi-processing backend for joblib: https://github.com/joblib/loky
    process_pool = Parallel(n_jobs=num_processes, backend="loky", verbose=0)
    tasks = (delayed(decode_sequentially)(chunk, video_path) for chunk in chunks)
    frame_batches = process_pool(tasks)
    tensors = [batch.data for batch in frame_batches]
    return torch.cat(tensors, dim=0)
# One loky worker process per CPU reported by joblib.
times, result_multiprocessing = bench(
    decode_with_multiprocessing,
    all_indices,
    num_processes=NUM_CPUS,
)
multiprocessing_time = report_stats(times, "s")
speedup = sequential_time / multiprocessing_time
print(f"Speedup compared to sequential: {speedup:.2f}x with {NUM_CPUS} processes.")
median = 5.39s ± 0.01
Speedup compared to sequential: 2.65x with 16 processes.
方法 4:Joblib 多线程
基于线程的并行在单个进程中使用多个线程。TorchCodec 会释放 GIL,因此这可能非常有效。
def decode_with_multithreading(
    indices: List[int],
    num_threads: int,
    video_path=long_video_path
):
    """Decode the frames at ``indices`` across several threads in this process.

    The indices are split into ``num_threads`` contiguous chunks. Each chunk
    is decoded by ``decode_sequentially`` on a joblib thread
    (``prefer="threads"``).

    Args:
        indices: Frame indices to decode.
        num_threads: Number of worker threads (and chunks).
        video_path: Video to decode; defaults to the long generated video.

    Returns:
        A single tensor: the chunks' ``.data`` tensors concatenated along
        dim 0, in the original index order.
    """
    thread_pool = Parallel(n_jobs=num_threads, prefer="threads", verbose=0)
    frame_batches = thread_pool(
        delayed(decode_sequentially)(chunk, video_path)
        for chunk in split_indices(indices, num_chunks=num_threads)
    )
    # Stitch the per-thread results back together in order.
    return torch.cat([batch.data for batch in frame_batches], dim=0)
# One joblib worker thread per CPU reported by joblib.
times, result_multithreading = bench(
    decode_with_multithreading,
    all_indices,
    num_threads=NUM_CPUS,
)
multithreading_time = report_stats(times, "s")
speedup = sequential_time / multithreading_time
print(f"Speedup compared to sequential: {speedup:.2f}x with {NUM_CPUS} threads.")
median = 1.94s ± 0.01
Speedup compared to sequential: 7.38x with 16 threads.
验证和正确性检查
让我们验证所有方法是否都产生了相同的结果。
# Every parallel strategy must reproduce the sequential frames exactly:
# atol=0 and rtol=0 turn assert_close into a strict equality check.
# The multiprocessing/multithreading helpers already return plain tensors,
# while the FFmpeg result is a frame batch, hence `.data` only there.
for candidate in (result_ffmpeg.data, result_multiprocessing, result_multithreading):
    torch.testing.assert_close(result_sequential.data, candidate, atol=0, rtol=0)
print("All good!")
All good!
# Delete the scratch directory with the downloaded and generated videos.
import shutil

shutil.rmtree(temp_dir)
脚本总运行时间: (2 分钟 1.981 秒)