如何采样视频剪辑¶
在此示例中,我们将学习如何从视频中采样视频 剪辑。剪辑通常指帧的序列或批次,并且通常作为视频模型的输入传递。
首先,是一些样板代码:我们将从网上下载一个视频,并定义一个绘图工具。您可以忽略这部分,直接跳转到 创建解码器。
from typing import Optional
import torch
import requests
# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
raise RuntimeError(f"Failed to download video. {response.status_code = }.")
raw_video_bytes = response.content
def plot(frames: torch.Tensor, title : Optional[str] = None):
try:
from torchvision.utils import make_grid
from torchvision.transforms.v2.functional import to_pil_image
import matplotlib.pyplot as plt
except ImportError:
print("Cannot plot, please run `pip install torchvision matplotlib`")
return
plt.rcParams["savefig.bbox"] = 'tight'
fig, ax = plt.subplots()
ax.imshow(to_pil_image(make_grid(frames)))
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if title is not None:
ax.set_title(title)
plt.tight_layout()
创建解码器¶
从视频中采样剪辑总是从创建一个 VideoDecoder
对象开始。如果您还不熟悉 VideoDecoder
,请快速查看:使用 VideoDecoder 解码视频。
from torchcodec.decoders import VideoDecoder
# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)
采样基础知识¶
我们现在可以使用解码器来采样剪辑。让我们先看一个简单的例子:所有其他采样器都遵循类似的 API 和原理。我们将使用 clips_at_random_indices()
来采样从随机索引开始的剪辑。
from torchcodec.samplers import clips_at_random_indices
# The samplers RNG is controlled by pytorch's RNG. We set a seed for this
# tutorial to be reproducible across runs, but note that hard-coding a seed for
# a training run is generally not recommended.
torch.manual_seed(0)
clips = clips_at_random_indices(
decoder,
num_clips=5,
num_frames_per_clip=4,
num_indices_between_frames=3,
)
clips
FrameBatch:
data (shape): torch.Size([5, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[10.2000, 10.3200, 10.4400, 10.5600],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 9.6000, 9.7200, 9.8400, 9.9600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
采样器的输出是一系列剪辑,表示为 FrameBatch
对象。在此对象中,我们有不同的字段
data
: 一个 5D uint8 张量,表示帧数据。其形状为 (num_clips, num_frames_per_clip, …),其中 … 是 (C, H, W) 或 (H, W, C),具体取决于VideoDecoder
的dimension_order
参数。这通常会传递给模型。pts_seconds
: 一个形状为 (num_clips, num_frames_per_clip) 的 2D 浮点张量,给出每个剪辑中每帧的起始时间戳(以秒为单位)。duration_seconds
: 一个形状为 (num_clips, num_frames_per_clip) 的 2D 浮点张量,给出每个剪辑中每帧的时长(以秒为单位)。
plot(clips[0].data)

索引和操作剪辑¶
剪辑是 FrameBatch
对象,它们支持原生的 PyTorch 索引语义(包括花式索引)。这使得根据给定标准轻松过滤剪辑变得容易。例如,从上面的剪辑中,我们可以轻松地过滤掉那些在特定时间戳之后开始的剪辑
clip_starts = clips.pts_seconds[:, 0]
clip_starts
tensor([11.3600, 10.2000, 9.8000, 9.6000, 8.4400], dtype=torch.float64)
clips_starting_after_five_seconds = clips[clip_starts > 5]
clips_starting_after_five_seconds
FrameBatch:
data (shape): torch.Size([5, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[10.2000, 10.3200, 10.4400, 10.5600],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 9.6000, 9.7200, 9.8400, 9.9600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
every_other_clip = clips[::2]
every_other_clip
FrameBatch:
data (shape): torch.Size([3, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
注意
获取给定时间戳之后剪辑的一种更自然、更有效的方法是依赖采样范围参数,我们将在后面的 高级参数:采样范围 中介绍。
基于索引和基于时间的采样器¶
到目前为止,我们使用了 clips_at_random_indices()
。Torchcodec 支持其他采样器,它们分为两大类:
基于索引的采样器
基于时间的采样器
所有这些采样器都遵循类似的 API,并且基于时间的采样器具有与基于索引的采样器类似的参数。两种采样器类型在速度方面通常具有相当的性能。
注意
使用基于时间的采样器还是基于索引的采样器更好?基于索引的采样器具有更简单的 API,并且由于索引的离散性质,其行为可能更容易理解和控制。对于具有恒定帧率的视频,基于索引的采样器与基于时间的采样器行为完全相同。但是,对于具有可变帧率的视频(这很常见),依赖索引可能会对视频的某些区域进行欠采样/过采样,这可能导致模型训练时产生不良的副作用。使用基于时间的采样器可确保时间维度的统一采样特性。
高级参数:采样范围¶
有时,我们可能不希望从整个视频中采样剪辑。我们可能只对在较小区间内开始的剪辑感兴趣。在采样器中,sampling_range_start
和 sampling_range_end
参数控制采样范围:它们定义了允许剪辑 *开始* 的位置。有两件重要的事情需要牢记:
sampling_range_end
是一个 *开放* 的上限:剪辑只能在 [sampling_range_start, sampling_range_end) 范围内开始。由于这些参数定义了剪辑可以开始的位置,剪辑可能包含
sampling_range_end
之后的帧!
from torchcodec.samplers import clips_at_regular_timestamps
clips = clips_at_regular_timestamps(
decoder,
seconds_between_clip_starts=1,
num_frames_per_clip=4,
seconds_between_frames=0.5,
sampling_range_start=2,
sampling_range_end=5
)
clips
FrameBatch:
data (shape): torch.Size([3, 4, 3, 360, 640])
pts_seconds: tensor([[2.0000, 2.4800, 3.0000, 3.4800],
[3.0000, 3.4800, 4.0000, 4.4800],
[4.0000, 4.4800, 5.0000, 5.4800]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
高级参数:策略¶
根据视频的时长或持续时间以及采样参数,采样器可能会尝试采样视频末尾之外的帧。policy
参数定义了如何用有效帧替换此类无效帧。
from torchcodec.samplers import clips_at_random_timestamps
end_of_video = decoder.metadata.end_stream_seconds
print(f"{end_of_video = }")
end_of_video = 13.8
torch.manual_seed(0)
clips = clips_at_random_timestamps(
decoder,
num_clips=1,
num_frames_per_clip=5,
seconds_between_frames=0.4,
sampling_range_start=end_of_video - 1,
sampling_range_end=end_of_video,
policy="repeat_last",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.6800, 13.6800, 13.6800]], dtype=torch.float64)
我们上面看到视频的末尾在 13.8 秒。采样器尝试在时间戳 [13.28, 13.68, 14.08, …] 处采样帧,但 14.08 是一个无效的时间戳,超出了视频末尾。使用“repeat_last”策略(这是默认策略),采样器会简单地重复 13.68 秒的最后一帧来构建剪辑。
另一种策略是“wrap”:采样器然后围绕剪辑进行包装,并在必要时重复前几帧有效帧
torch.manual_seed(0)
clips = clips_at_random_timestamps(
decoder,
num_clips=1,
num_frames_per_clip=5,
seconds_between_frames=0.4,
sampling_range_start=end_of_video - 1,
sampling_range_end=end_of_video,
policy="wrap",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.2800, 13.6800, 13.2800]], dtype=torch.float64)
默认情况下,sampling_range_end
的值会自动设置为使采样器 *不* 尝试采样视频末尾之外的帧:默认值可确保剪辑在末尾之前足够早地开始。这意味着默认情况下,policy 参数很少生效,大多数用户可能不必过多担心它。
脚本的总运行时间: (0 分钟 0.612 秒)