• 文档 >
  • 开始数据收集和存储
快捷方式

使用数据收集和存储入门

作者Vincent Moens

注意

要在 notebook 中运行本教程,请在开头添加一个安装单元格,其中包含:

!pip install tensordict
!pip install torchrl
import tempfile

没有数据就没有学习。在监督学习中,用户习惯于使用 DataLoader 等工具将数据集成到其训练循环中。DataLoader 是可迭代对象,它们为您提供将用于训练模型的数据。

TorchRL 在数据加载问题上的方法类似,尽管它在 RL 库生态系统中却出奇地独特。TorchRL 的数据加载器被称为 DataCollectors。大多数时候,数据收集并不仅限于原始数据的收集,因为数据需要在临时缓冲区(或对在线策略算法的等效结构)中存储,然后才能被 损失模块 消耗。本教程将探讨这两个类。

数据收集器

这里讨论的主要数据收集器是 SyncDataCollector,这是本档的重点。从根本上说,收集器是一个简单的类,负责在环境中执行您的策略、在必要时重置环境,并提供预定义大小的批次。与 环境教程 中演示的 rollout() 方法不同,收集器在连续的数据批次之间不会重置。因此,两个连续的数据批次可能包含来自同一轨迹的元素。

您需要传递给收集器的基本参数是您想要收集的批次大小(frames_per_batch)、迭代器的长度(可能无限)、策略和环境。为简单起见,我们将在示例中使用一个模拟的随机策略。

import torch

from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy

torch.manual_seed(0)

env = GymEnv("CartPole-v1")
env.set_seed(0)

policy = RandomPolicy(env.action_spec)
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)

我们现在期望我们的收集器将提供大小为 200 的批次,无论收集过程中发生什么。换句话说,我们这个批次中可能包含多个轨迹!total_frames 表示收集器应该有多长。值为 -1 将产生一个永不结束的收集器。

让我们迭代收集器,了解一下这些数据是什么样的

for data in collector:
    print(data)
    break

正如您所见,我们的数据被添加了一些收集器特定的元数据,这些元数据被分组在一个 "collector" 子 Tensordict 中,我们在 环境回放 中没有看到这些。这对于跟踪轨迹 ID 非常有用。在下面的列表中,每个项目都标记了对应转换所属的轨迹编号。

print(data["collector", "traj_ids"])

数据收集器在编写最先进的算法时非常有用,因为性能通常通过特定技术在与环境交互的给定次数内解决问题的能力来衡量(收集器中的 total_frames 参数)。因此,我们示例中的大多数训练循环如下所示:

..code - block::Python

>>> for data in collector:
...     # your algorithm here

回放缓冲区

既然我们已经探讨了如何收集数据,我们想知道如何存储它。在 RL 中,典型的情况是数据被收集、临时存储,并在一段时间后根据某种启发式方法清除:先进先出或其他。典型的伪代码如下:

..code - block::Python

>>> for data in collector:
...     storage.store(data)
...     for i in range(n_optim):
...         sample = storage.sample()
...         loss_val = loss_fn(sample)
...         loss_val.backward()
...         optim.step() # etc

TorchRL 中存储数据的父类被称为 ReplayBuffer。TorchRL 的回放缓冲区是可组合的:您可以编辑存储类型、采样技术、写入启发式方法或应用于它们的转换。我们将在专门的深入教程中介绍这些花哨的功能。通用的回放缓冲区只需要知道要使用哪种存储。通常,我们推荐使用 TensorStorage 子类,它在大多数情况下都能正常工作。在本教程中,我们将使用 LazyMemmapStorage,它具有两个优点:首先,“懒惰”,您无需提前显式告诉它您的数据外观。其次,它使用 MemoryMappedTensor 作为后端,以高效的方式将数据保存到磁盘。您唯一需要知道的是您想要的缓冲区大小。

from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer

buffer_scratch_dir = tempfile.TemporaryDirectory().name

buffer = ReplayBuffer(
    storage=LazyMemmapStorage(max_size=1000, scratch_dir=buffer_scratch_dir)
)

可以通过 add()(单个元素)或 extend()(多个元素)方法来填充缓冲区。使用我们刚刚收集的数据,我们一次性初始化和填充缓冲区。

indices = buffer.extend(data)

我们可以检查缓冲区现在拥有的元素数量与我们从收集器获得的数量相同。

assert len(buffer) == collector.frames_per_batch

唯一需要了解的是如何从缓冲区收集数据。自然,这依赖于 sample() 方法。因为我们没有指定采样必须无重复进行,所以无法保证从缓冲区收集的样本是唯一的。

sample = buffer.sample(batch_size=30)
print(sample)

再次,我们的样本看起来与我们从收集器收集的数据完全一样!

后续步骤

  • 您可以查看其他多进程收集器,例如 MultiSyncDataCollectorMultiaSyncDataCollector

  • TorchRL 还提供分布式收集器,如果您有多个节点用于推理。请在 API 参考 中查看它们。

  • 查看专门的 回放缓冲区教程 以了解构建缓冲区时可用的选项,或者 API 参考,其中详细介绍了所有功能。回放缓冲区拥有无数的功能,例如多线程采样、优先经验回放等等……

  • 为了简单起见,我们省略了回放缓冲区可迭代的能力。您可以自己尝试:构建一个缓冲区并在构造函数中指定其批次大小,然后尝试迭代它。这相当于在一个循环中调用 rb.sample()

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源