• 文档 >
  • 开始数据收集和存储
快捷方式

使用数据收集和存储入门

作者Vincent Moens

注意

要在 notebook 中运行本教程,请在开头添加一个安装单元格,其中包含:

!pip install tensordict
!pip install torchrl
import tempfile

没有数据就没有学习。在监督学习中,用户习惯于使用 DataLoader 等工具将数据集成到训练循环中。DataLoader 是可迭代对象,它们为您提供将用于训练模型的数据。

TorchRL 以类似的方式处理数据加载问题,尽管它在 RL 库生态系统中出奇地独特。TorchRL 的 DataLoader 被称为 DataCollectors。大多数情况下,数据收集不仅仅是收集原始数据,因为数据需要临时存储在缓冲区(或等效结构,用于 on-policy 算法)中,然后才能被 损失模块 使用。本教程将探讨这两个类。

数据收集器

这里讨论的主要数据收集器是 SyncDataCollector,这是本篇文档的重点。从根本上说,收集器是一个简单的类,负责在环境中执行您的策略、在必要时重置环境,并提供预定义大小的批次。与 环境教程 中演示的 rollout() 方法不同,收集器在连续的数据批次之间不会重置。因此,两个连续的数据批次可能包含来自同一轨迹的元素。

您需要传递给收集器的基本参数是您想要收集的批次大小(frames_per_batch)、迭代器的长度(可能是无限的)、策略和环境。为了简单起见,我们在本例中使用一个虚拟的、随机的策略。

import torch

from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy

torch.manual_seed(0)

env = GymEnv("CartPole-v1")
env.set_seed(0)

policy = RandomPolicy(env.action_spec)
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)

我们现在预期我们的收集器将提供大小为 200 的批次,无论收集过程中发生什么。换句话说,我们可能在这个批次中有多个轨迹!total_frames 表示收集器应该持续多长时间。值为 -1 将生成一个永不结束的收集器。

让我们迭代收集器,以了解这些数据是什么样的

for data in collector:
    print(data)
    break

正如您所见,我们的数据通过一个我们之前在 环境 rollout 中未见过的 "collector" 子 tensordict 进行了分组,并添加了一些收集器特定的元数据。这对于跟踪轨迹 ID 非常有用。在下面的列表中,每个项目标记了相应转换所属的轨迹编号

print(data["collector", "traj_ids"])

当涉及到编写最先进的算法时,数据收集器非常有用,因为性能通常通过特定技术在给定数量的环境交互中解决问题的能力来衡量(收集器中的 total_frames 参数)。因此,我们示例中的大多数训练循环都如下所示

..code - block::Python

>>> for data in collector:
...     # your algorithm here

回放缓冲区

现在我们已经了解了如何收集数据,我们想知道如何存储它。在 RL 中,典型设置是收集数据,临时存储,并在一段时间后根据某些启发式方法清除:先进先出或其他。典型的伪代码如下所示

..code - block::Python

>>> for data in collector:
...     storage.store(data)
...     for i in range(n_optim):
...         sample = storage.sample()
...         loss_val = loss_fn(sample)
...         loss_val.backward()
...         optim.step() # etc

TorchRL 中存储数据的父类被称为 ReplayBuffer。TorchRL 的回放缓冲区是可组合的:您可以编辑存储类型、采样技术、写入启发式方法或应用的转换。我们将把花哨的东西留给专门的深度教程。通用回放缓冲区只需要知道它必须使用什么存储。一般来说,我们推荐 TensorStorage 子类,它在大多数情况下都能很好地工作。在本教程中,我们将使用 LazyMemmapStorage,它具有两个很好的特性:首先,“惰性”意味着您无需提前明确告知它您的数据是什么样的。其次,它使用 MemoryMappedTensor 作为后端,以高效的方式将数据保存到磁盘。您唯一需要知道的是您希望缓冲区有多大。

from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer

buffer_scratch_dir = tempfile.TemporaryDirectory().name

buffer = ReplayBuffer(
    storage=LazyMemmapStorage(max_size=1000, scratch_dir=buffer_scratch_dir)
)

可以通过 add()(单个元素)或 extend()(多个元素)方法来填充缓冲区。使用我们刚刚收集的数据,我们一次性初始化和填充缓冲区

indices = buffer.extend(data)

我们可以检查缓冲区现在拥有的元素数量与我们从收集器中获得的数量相同

assert len(buffer) == collector.frames_per_batch

唯一需要知道的是如何从缓冲区中收集数据。当然,这依赖于 sample() 方法。因为我们没有指定无重复采样,所以从缓冲区收集的样本不保证是唯一的

sample = buffer.sample(batch_size=30)
print(sample)

再次,我们的样本看起来与我们从收集器中收集的数据完全相同!

后续步骤

  • 您可以查看其他多进程收集器,例如 MultiSyncDataCollectorMultiaSyncDataCollector

  • 如果您有多个节点用于推理,TorchRL 还提供分布式收集器。请在 API 参考 中查看它们。

  • 请查看专门的 回放缓冲区教程 以了解构建缓冲区时可用的选项,或 API 参考,其中详细介绍了所有功能。回放缓冲区具有无数功能,例如多线程采样、优先体验回放等等……

  • 为了简单起见,我们省略了回放缓冲区的可迭代能力。您可以自己尝试:构建一个缓冲区并在构造函数中指定其批次大小,然后尝试迭代它。这等同于在循环中调用 rb.sample()

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源