快捷方式

使用 VideoDecoder 解码视频

在此示例中,我们将学习如何使用 VideoDecoder 类来解码视频。

首先,一些样板代码:我们将从网上下载一个视频,并定义一个绘图实用程序。您可以忽略这部分,直接跳转到 创建解码器

from typing import Optional
import torch
import requests


# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")

raw_video_bytes = response.content


def plot(frames: torch.Tensor, title : Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()

创建解码器

我们现在可以从原始(编码)视频字节创建解码器。您当然也可以使用本地视频文件并将路径作为输入,而不是下载视频。

from torchcodec.decoders import VideoDecoder

# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)

视频尚未被解码器解码,但我们已经可以通过 metadata 属性访问一些元数据,该属性是一个 VideoStreamMetadata 对象。

print(decoder.metadata)
VideoStreamMetadata:
  duration_seconds_from_header: 13.8
  begin_stream_seconds_from_header: 0.0
  bit_rate: 505790.0
  codec: h264
  stream_index: 0
  begin_stream_seconds_from_content: 0.0
  end_stream_seconds_from_content: 13.8
  width: 640
  height: 360
  num_frames_from_header: 345
  num_frames_from_content: 345
  average_fps_from_header: 25.0
  pixel_aspect_ratio: 1
  duration_seconds: 13.8
  begin_stream_seconds: 0.0
  end_stream_seconds: 13.8
  num_frames: 345
  average_fps: 25.0

通过索引解码器来解码帧

first_frame = decoder[0]  # using a single int index
every_twenty_frame = decoder[0 : -1 : 20]  # using slices

print(f"{first_frame.shape = }")
print(f"{first_frame.dtype = }")
print(f"{every_twenty_frame.shape = }")
print(f"{every_twenty_frame.dtype = }")
first_frame.shape = torch.Size([3, 360, 640])
first_frame.dtype = torch.uint8
every_twenty_frame.shape = torch.Size([18, 3, 360, 640])
every_twenty_frame.dtype = torch.uint8

索引解码器会将帧作为 torch.Tensor 对象返回。默认情况下,帧的形状为 (N, C, H, W),其中 N 是批次大小,C 是通道数,H 是高度,W 是帧的宽度。批次维度 N 仅在我们解码多个帧时存在。可以使用 VideoDecoderdimension_order 参数将维度顺序更改为 N, H, W, C。帧的 dtype 始终是 torch.uint8

注意

如果您需要解码多个帧,我们建议使用批处理方法,因为它们速度更快:get_frames_at()get_frames_in_range()get_frames_played_at()get_frames_played_in_range()。它们在下面进行了描述。

plot(first_frame, "First frame")
First frame
plot(every_twenty_frame, "Every 20 frame")
Every 20 frame

遍历帧

解码器是一个正常的迭代对象,可以像这样进行迭代

for frame in decoder:
    assert (
        isinstance(frame, torch.Tensor)
        and frame.shape == (3, decoder.metadata.height, decoder.metadata.width)
    )

检索帧的 pts 和持续时间

索引解码器会返回纯粹的 torch.Tensor 对象。有时,检索帧的额外信息(例如它们的 pts(显示时间戳)和持续时间)会很有用。这可以通过 get_frame_at()get_frames_at() 方法实现,它们将分别返回 FrameFrameBatch 对象。

last_frame = decoder.get_frame_at(len(decoder) - 1)
print(f"{type(last_frame) = }")
print(last_frame)
type(last_frame) = <class 'torchcodec._frame.Frame'>
Frame:
  data (shape): torch.Size([3, 360, 640])
  pts_seconds: 13.76
  duration_seconds: 0.04
other_frames = decoder.get_frames_at([10, 0, 50])
print(f"{type(other_frames) = }")
print(other_frames)
type(other_frames) = <class 'torchcodec._frame.FrameBatch'>
FrameBatch:
  data (shape): torch.Size([3, 3, 360, 640])
  pts_seconds: tensor([0.4000, 0.0000, 2.0000], dtype=torch.float64)
  duration_seconds: tensor([0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(last_frame.data, "Last frame")
plot(other_frames.data, "Other frames")
  • Last frame
  • Other frames

FrameFrameBatch 都有一个 data 字段,其中包含解码后的张量数据。它们还具有 pts_secondsduration_seconds 字段,对于 Frame 是单个整数,对于 FrameBatch 是 1D torch.Tensor(批次中的每个帧一个值)。

使用基于时间的索引

到目前为止,我们都是根据帧的索引来检索帧的。我们也可以使用 get_frame_played_at()get_frames_played_at() 根据帧的播放时间来检索帧,它们也分别返回 FrameFrameBatch

frame_at_2_seconds = decoder.get_frame_played_at(seconds=2)
print(f"{type(frame_at_2_seconds) = }")
print(frame_at_2_seconds)
type(frame_at_2_seconds) = <class 'torchcodec._frame.Frame'>
Frame:
  data (shape): torch.Size([3, 360, 640])
  pts_seconds: 2.0
  duration_seconds: 0.04
other_frames = decoder.get_frames_played_at(seconds=[10.1, 0.3, 5])
print(f"{type(other_frames) = }")
print(other_frames)
type(other_frames) = <class 'torchcodec._frame.FrameBatch'>
FrameBatch:
  data (shape): torch.Size([3, 3, 360, 640])
  pts_seconds: tensor([10.0800,  0.2800,  5.0000], dtype=torch.float64)
  duration_seconds: tensor([0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(frame_at_2_seconds.data, "Frame played at 2 seconds")
plot(other_frames.data, "Other frames")
  • Frame played at 2 seconds
  • Other frames

脚本总运行时间: (0 分 3.125 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源