快捷方式

如何采样视频剪辑

在此示例中,我们将学习如何从视频中采样视频 剪辑。剪辑通常指帧的序列或批次,并且通常作为视频模型的输入传递。

首先,是一些样板代码:我们将从网上下载一个视频,并定义一个绘图工具。您可以忽略这部分,直接跳转到 创建解码器

from typing import Optional
import torch
import requests


# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")

raw_video_bytes = response.content


def plot(frames: torch.Tensor, title : Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()

创建解码器

从视频中采样剪辑总是从创建一个 VideoDecoder 对象开始。如果您还不熟悉 VideoDecoder,请快速查看:使用 VideoDecoder 解码视频

from torchcodec.decoders import VideoDecoder

# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)

采样基础知识

我们现在可以使用解码器来采样剪辑。让我们先看一个简单的例子:所有其他采样器都遵循类似的 API 和原理。我们将使用 clips_at_random_indices() 来采样从随机索引开始的剪辑。

from torchcodec.samplers import clips_at_random_indices

# The samplers RNG is controlled by pytorch's RNG. We set a seed for this
# tutorial to be reproducible across runs, but note that hard-coding a seed for
# a training run is generally not recommended.
torch.manual_seed(0)

clips = clips_at_random_indices(
    decoder,
    num_clips=5,
    num_frames_per_clip=4,
    num_indices_between_frames=3,
)
clips
FrameBatch:
  data (shape): torch.Size([5, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [10.2000, 10.3200, 10.4400, 10.5600],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 9.6000,  9.7200,  9.8400,  9.9600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

采样器的输出是一系列剪辑,表示为 FrameBatch 对象。在此对象中,我们有不同的字段

  • data: 一个 5D uint8 张量,表示帧数据。其形状为 (num_clips, num_frames_per_clip, …),其中 … 是 (C, H, W) 或 (H, W, C),具体取决于 VideoDecoderdimension_order 参数。这通常会传递给模型。

  • pts_seconds: 一个形状为 (num_clips, num_frames_per_clip) 的 2D 浮点张量,给出每个剪辑中每帧的起始时间戳(以秒为单位)。

  • duration_seconds: 一个形状为 (num_clips, num_frames_per_clip) 的 2D 浮点张量,给出每个剪辑中每帧的时长(以秒为单位)。

plot(clips[0].data)
sampling

索引和操作剪辑

剪辑是 FrameBatch 对象,它们支持原生的 PyTorch 索引语义(包括花式索引)。这使得根据给定标准轻松过滤剪辑变得容易。例如,从上面的剪辑中,我们可以轻松地过滤掉那些在特定时间戳之后开始的剪辑

tensor([11.3600, 10.2000,  9.8000,  9.6000,  8.4400], dtype=torch.float64)
clips_starting_after_five_seconds = clips[clip_starts > 5]
clips_starting_after_five_seconds
FrameBatch:
  data (shape): torch.Size([5, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [10.2000, 10.3200, 10.4400, 10.5600],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 9.6000,  9.7200,  9.8400,  9.9600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
every_other_clip = clips[::2]
every_other_clip
FrameBatch:
  data (shape): torch.Size([3, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

注意

获取给定时间戳之后剪辑的一种更自然、更有效的方法是依赖采样范围参数,我们将在后面的 高级参数:采样范围 中介绍。

基于索引和基于时间的采样器

到目前为止,我们使用了 clips_at_random_indices()。Torchcodec 支持其他采样器,它们分为两大类:

基于索引的采样器

基于时间的采样器

所有这些采样器都遵循类似的 API,并且基于时间的采样器具有与基于索引的采样器类似的参数。两种采样器类型在速度方面通常具有相当的性能。

注意

使用基于时间的采样器还是基于索引的采样器更好?基于索引的采样器具有更简单的 API,并且由于索引的离散性质,其行为可能更容易理解和控制。对于具有恒定帧率的视频,基于索引的采样器与基于时间的采样器行为完全相同。但是,对于具有可变帧率的视频(这很常见),依赖索引可能会对视频的某些区域进行欠采样/过采样,这可能导致模型训练时产生不良的副作用。使用基于时间的采样器可确保时间维度的统一采样特性。

高级参数:采样范围

有时,我们可能不希望从整个视频中采样剪辑。我们可能只对在较小区间内开始的剪辑感兴趣。在采样器中,sampling_range_startsampling_range_end 参数控制采样范围:它们定义了允许剪辑 *开始* 的位置。有两件重要的事情需要牢记:

  • sampling_range_end 是一个 *开放* 的上限:剪辑只能在 [sampling_range_start, sampling_range_end) 范围内开始。

  • 由于这些参数定义了剪辑可以开始的位置,剪辑可能包含 sampling_range_end 之后的帧!

from torchcodec.samplers import clips_at_regular_timestamps

clips = clips_at_regular_timestamps(
    decoder,
    seconds_between_clip_starts=1,
    num_frames_per_clip=4,
    seconds_between_frames=0.5,
    sampling_range_start=2,
    sampling_range_end=5
)
clips
FrameBatch:
  data (shape): torch.Size([3, 4, 3, 360, 640])
  pts_seconds: tensor([[2.0000, 2.4800, 3.0000, 3.4800],
        [3.0000, 3.4800, 4.0000, 4.4800],
        [4.0000, 4.4800, 5.0000, 5.4800]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

高级参数:策略

根据视频的时长或持续时间以及采样参数,采样器可能会尝试采样视频末尾之外的帧。policy 参数定义了如何用有效帧替换此类无效帧。

from torchcodec.samplers import clips_at_random_timestamps

end_of_video = decoder.metadata.end_stream_seconds
print(f"{end_of_video = }")
end_of_video = 13.8
torch.manual_seed(0)
clips = clips_at_random_timestamps(
    decoder,
    num_clips=1,
    num_frames_per_clip=5,
    seconds_between_frames=0.4,
    sampling_range_start=end_of_video - 1,
    sampling_range_end=end_of_video,
    policy="repeat_last",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.6800, 13.6800, 13.6800]], dtype=torch.float64)

我们上面看到视频的末尾在 13.8 秒。采样器尝试在时间戳 [13.28, 13.68, 14.08, …] 处采样帧,但 14.08 是一个无效的时间戳,超出了视频末尾。使用“repeat_last”策略(这是默认策略),采样器会简单地重复 13.68 秒的最后一帧来构建剪辑。

另一种策略是“wrap”:采样器然后围绕剪辑进行包装,并在必要时重复前几帧有效帧

torch.manual_seed(0)
clips = clips_at_random_timestamps(
    decoder,
    num_clips=1,
    num_frames_per_clip=5,
    seconds_between_frames=0.4,
    sampling_range_start=end_of_video - 1,
    sampling_range_end=end_of_video,
    policy="wrap",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.2800, 13.6800, 13.2800]], dtype=torch.float64)

默认情况下,sampling_range_end 的值会自动设置为使采样器 *不* 尝试采样视频末尾之外的帧:默认值可确保剪辑在末尾之前足够早地开始。这意味着默认情况下,policy 参数很少生效,大多数用户可能不必过多担心它。

脚本的总运行时间: (0 分钟 0.612 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源