VideoRecorder¶
- torchrl.record.VideoRecorder(logger: Logger, tag: str, in_keys: Sequence[NestedKey] | None = None, skip: int | None = None, center_crop: int | None = None, make_grid: bool | None = None, out_keys: Sequence[NestedKey] | None = None, fps: int | None = None, **kwargs) None [源码]¶
视频录制器转换。
将从环境中记录一系列观察,并在需要时将它们写入 Logger 对象。
- 参数:
logger (Logger) – 一个 Logger 实例,视频将写入其中。要将视频保存为 memmap 张量或 mp4 文件,请使用
CSVLogger
类。tag (str) – Logger 中的视频标签。
in_keys (Sequence of NestedKey, optional) – 用于生成视频的读取键。默认为
"pixels"
。skip (int) – 输出视频中的帧间隔。如果转换具有父环境,则默认为
2
,如果不是,则默认为1
。center_crop (int, optional) – 方形中心裁剪的值。
make_grid (bool, optional) – 如果为
True
,则假设提供了形状为 [B x W x H x 3] 的张量,其中 B 是批次大小,并创建一个网格。如果转换具有父环境,则默认为True
,如果不是,则默认为False
。out_keys (sequence of NestedKey, optional) – 目标键。如果未提供,则默认为
in_keys
。fps (int, optional) – 输出视频的每秒帧数。默认为 Logger 预定义的
fps
,如果提供则覆盖它。**kwargs (Dict[str, Any], optional) –
log_video()
的其他关键字参数。
示例
以下示例显示了如何将一个 rollout 保存为视频。首先是一些导入
>>> from torchrl.record import VideoRecorder >>> from torchrl.record.loggers.csv import CSVLogger >>> from torchrl.envs import TransformedEnv, DMControlEnv
视频格式在 Logger 中选择。Wandb 和 Tensorboard 会自行处理。CSV 接受各种视频格式。
>>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4")
一些环境(例如 Atari 游戏)原生返回图像,有些需要用户主动请求。请参阅
GymEnv
或DMControlEnv
,了解如何在这些场景中渲染图像。>>> base_env = DMControlEnv("cheetah", "run", from_pixels=True) >>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video")) >>> env.rollout(100)
所有转换都有一个 dump 函数,大多数情况下是一个 no-op,除了
VideoRecorder
和Compose
,后者会将 dumps 分派给其所有成员。>>> env.transform.dump()
转换也可以在数据集内使用,以保存收集到的视频。与环境中的情况不同,图像将以批次的形式传入。
skip
参数将允许仅在特定间隔保存图像。>>> from torchrl.data.datasets import OpenXExperienceReplay >>> from torchrl.envs import Compose >>> from torchrl.record import VideoRecorder, CSVLogger >>> # Create a logger that saves videos as mp4 using 24 frames per sec >>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24) >>> # We use the VideoRecorder transform to save register the images coming from the batch. >>> # Setting the fps to 12 overrides the one set in the logger, not doing so keeps it unchanged. >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12) >>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False) >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200, ... download=True, strict_length=False, ... transform=t) >>> # Get a batch of data and visualize it >>> for data in dataset: ... t.dump() ... break
我们的视频可在
./cheetah_videos/cheetah/videos/run_video_0.mp4
下找到!