快捷方式

from torchrl.collectors import SyncDataCollector.. currentmodule:: torchrl.collectors

torchrl.collectors 包

数据收集器在某种程度上等同于 PyTorch 的数据加载器,不同之处在于 (1) 它们从非静态数据源收集数据,(2) 数据是使用模型(可能是正在训练的模型版本)收集的。

TorchRL 的数据收集器接受两个主要参数:一个环境(或一个环境构造函数列表)和一个策略。它们将根据定义的步数迭代地执行环境步和策略查询,然后将收集到的数据堆栈传递给用户。当环境达到完成状态以及/或经过预定义的步数后,它们将被重置。

由于数据收集是一个潜在的计算密集型过程,因此配置执行超参数至关重要。要考虑的第一个参数是数据收集应该与优化步骤串行进行还是并行进行。SyncDataCollector 类将在训练工作器上执行数据收集。MultiSyncDataCollector 将工作负载分配给多个工作器,并聚合将传递给训练工作器的结果。最后,MultiaSyncDataCollector 将在多个工作器上执行数据收集,并传递它可以收集到的第一批结果。此执行将持续不断地与网络训练同时进行:这意味着用于数据收集的策略权重可能会略微滞后于训练工作器上的策略配置。因此,虽然此类可能是收集数据的最快方法,但其代价是仅适用于可以接受异步收集数据的场景(例如,离策略强化学习或课程强化学习)。对于远程执行的 rollout(MultiSyncDataCollectorMultiaSyncDataCollector),有必要使用 `collector.update_policy_weights_()` 或在构造函数中设置 `update_at_each_batch=True` 来同步远程策略的权重与训练工作器上的权重。

要考虑的第二个参数(在远程设置中)是数据将被收集的设备以及将执行环境和策略操作的设备。例如,在 CPU 上执行的策略可能比在 CUDA 上执行的策略慢。当多个推理工作器同时运行时,跨可用设备分派计算工作负载可以加快收集速度或避免 OOM 错误。最后,批次大小的选择和传递设备(即数据等待传递给收集工作器时存储数据的设备)也可能影响内存管理。要控制的关键参数是 devices,它控制执行设备(即策略的设备),以及 storing_device,它控制回滚期间存储环境和数据的设备。一个好的经验法则是通常使用相同的设备进行存储和计算,当只传递 devices 参数时,这就是默认行为。

除了这些计算参数之外,用户还可以选择配置以下参数

  • max_frames_per_traj:调用 env.reset() 之前的帧数

  • frames_per_batch:每次迭代收集器传递的帧数

  • init_random_frames:随机步数(调用 env.rand_step() 的步数)

  • reset_at_each_iter:如果为 True,则在每次批次收集后重置环境

  • split_trajs:如果为 True,则将轨迹拆分并以填充的 tensordict 的形式传递,并带有 "mask" 键,该键指向表示有效值的布尔掩码。

  • exploration_type:与策略一起使用的探索策略。

  • reset_when_done:达到完成状态时是否重置环境。

收集器和批次大小

由于每个收集器都有自己组织内部运行的环境的方式,因此根据收集器的具体特性,数据会带有不同的批次大小。下表总结了收集数据时的情况

SyncDataCollector

MultiSyncDataCollector (n=B)

MultiaSyncDataCollector (n=B)

cat_results

NA

“stack”

0

-1

NA

单个环境

[T]

[B, T]

[B*(T//B)

[B*(T//B)]

[T]

批处理环境 (n=P)

[P, T]

[B, P, T]

[B * P, T]

[P, T * B]

[P, T]

在所有这些情况下,最后一个维度(T 表示 time)会进行调整,以使批次大小等于传递给收集器的 frames_per_batch 参数。

警告

MultiSyncDataCollector 不应与 cat_results=0 一起使用,因为数据将沿着批次维度堆叠,对于批处理环境,或者沿着时间维度堆叠,对于单个环境,这可能会在两者之间切换时造成混淆。cat_results="stack" 是与环境交互的更好、更一致的方式,因为它会保持每个维度分开,并提供与配置、收集器类和其他组件更好的互换性。

虽然 MultiSyncDataCollector 有一个对应于运行的子收集器数量的维度(B),但 MultiaSyncDataCollector 没有。考虑到 MultiaSyncDataCollector 是基于“先到先得”的原则传递数据批次的,而 MultiSyncDataCollector 在传递数据之前会从每个子收集器收集数据,因此这一点很容易理解。

收集器和策略副本

在将策略传递给收集器时,我们可以选择策略运行的设备。这可以用于将策略的训练版本保留在一个设备上,并将推理版本保留在另一个设备上。例如,如果您有两个 CUDA 设备,最好在一个设备上进行训练,在另一个设备上执行策略进行推理。在这种情况下,可以使用 `update_policy_weights_()` 将参数从一个设备复制到另一个设备(如果不需要复制,此方法将不起作用)。

由于目标是避免显式调用 `policy.to(policy_device)`,收集器将在实例化期间执行策略结构的深拷贝并在必要时复制放在新设备上的参数。由于并非所有策略都支持深拷贝(例如,使用 CUDA 图或依赖第三方库的策略),我们会尽量限制深拷贝的执行情况。下表显示了何时会发生这种情况。

../_images/collector-copy.png

收集器中的策略复制决策树。

分布式环境中的权重同步

在分布式和多进程环境中,确保所有策略实例都与最新的训练权重同步对于性能一致至关重要。该 API 引入了一个灵活且可扩展的机制,用于在不同设备和进程之间更新策略权重,以适应各种部署场景。

使用 WeightUpdaters 发送和接收模型权重

权重同步过程通过一个专用的扩展点来实现:WeightUpdaterBase。这个基类提供了一个结构化接口来实施自定义权重更新逻辑,允许用户根据自身特定需求定制同步过程。

WeightUpdaterBase 负责将策略权重分发给策略或远程推理工作器,并在必要时从服务器格式化/收集权重。每个收集器——服务器或工作器——都应该有一个 `WeightUpdaterBase` 实例来处理与策略的权重同步。即使是最简单的收集器也使用 VanillaWeightUpdater 实例来更新策略的状态字典(假设它是一个 Module 实例)。

扩展 Updater 类

为了适应不同的用例,API 允许用户使用自定义实现来扩展 updater 类。目标是能够自定义权重同步策略,同时保持收集器和策略实现不变。这种灵活性在涉及复杂网络架构或专用硬件设置的场景中特别有利。通过实现这些基类中的抽象方法,用户可以定义如何检索、转换和应用权重,从而确保与现有基础设施的无缝集成。

WeightUpdaterBase()

一个用于在远程推理工作器上更新本地策略权重的基类。

VanillaWeightUpdater(*[, weight_getter])

WeightUpdaterBase 的一个简单实现,用于更新本地策略权重。

MultiProcessedWeightUpdater(*, ...)

一个用于在多个进程或设备之间同步策略权重的远程权重更新器。

RayWeightUpdater(policy_weights, ...[, ...])

一个使用 Ray 在远程工作器之间同步策略权重的远程权重更新器。

RPCWeightUpdater(collector_infos, ...)

一个使用 RPC 在远程工作器之间同步策略权重的远程权重更新器。

DistributedWeightUpdater(store, ...)

一个用于在分布式工作器之间同步策略权重的远程权重更新器。

收集器和回放缓冲区互操作性

在需要从回放缓冲区采样单个转换的最简单场景中,几乎不需要关注收集器的构建方式。收集后的数据展平将是填充存储的足够预处理步骤

>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N),
...     transform=lambda data: data.reshape(-1))
>>> for data in collector:
...     memory.extend(data)

如果需要收集轨迹切片,推荐的方法是创建一个多维缓冲区并使用 SliceSampler 采样器类进行采样。必须确保传递给缓冲区的数据形状正确,并且 timebatch 维度清晰地分开。在实践中,以下配置将起作用

>>> # Single environment: no need for a multi-dimensional buffer
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N),
...     sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
...     memory.extend(data)
>>> # Batched environments: a multi-dim buffer is required
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=2),
...     sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> env = ParallelEnv(4, make_env)
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
...     memory.extend(data)
>>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack"
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=2),
...     sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([make_env] * 4,
...     policy,
...     frames_per_batch=N,
...     total_frames=-1,
...     cat_results="stack")
>>> for data in collector:
...     memory.extend(data)
>>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=3),
...     sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4,
...     policy,
...     frames_per_batch=N,
...     total_frames=-1,
...     cat_results="stack")
>>> for data in collector:
...     memory.extend(data)

目前尚不支持使用通过 MultiSyncDataCollector 采样轨迹的回放缓冲区,因为数据批次可能来自任何工作器,并且在大多数情况下,写入缓冲区中的连续批次不会来自同一来源(从而中断轨迹)。

异步运行收集器

将回放缓冲区传递给收集器使我们能够开始收集并摆脱收集器的迭代性质。如果您想在后台运行数据收集器,只需运行 `start()`

>>> collector = SyncDataCollector(..., replay_buffer=rb) # pass your replay buffer
>>> collector.start()
>>> # little pause
>>> time.sleep(10)
>>> # Start training
>>> for i in range(optim_steps):
...     data = rb.sample()  # Sampling from the replay buffer
...     # rest of the training loop

单进程收集器(SyncDataCollector)将使用多线程运行进程,因此请注意 Python 的 GIL 和相关的多线程限制。

另一方面,多进程收集器将让子进程自行处理缓冲区的填充,这真正将数据收集和训练解耦。

使用 `start()` 启动的数据收集器应使用 `async_shutdown()` 关闭。

警告

异步运行收集器会将收集与训练解耦,这意味着训练性能可能会因硬件、负载和其他因素而有很大差异(尽管通常预计会提供显著的加速)。请确保您了解这可能如何影响您的算法以及这是否是一件可以做的事情!(例如,像 PPO 这样的同策略算法不应异步运行,除非经过充分的基准测试)。

单节点数据收集器

DataCollectorBase()

数据收集器的基类。

SyncDataCollector(create_env_fn[, policy, ...])

用于强化学习问题的通用数据收集器。

MultiSyncDataCollector(create_env_fn[, ...])

在单独的进程中同步运行指定数量的 DataCollectors。

MultiaSyncDataCollector(*args, **kwargs)

在单独的进程中异步运行指定数量的 DataCollectors。

aSyncDataCollector(create_env_fn[, policy, ...])

在单独的进程中运行单个 DataCollector。

分布式数据收集器

TorchRL 提供了一套分布式数据收集器。这些工具支持多种后端('gloo''nccl''mpi',使用 DistributedDataCollector 或 PyTorch RPC,使用 RPCDataCollector)和启动器('ray'、`submitit` 或 `torch.multiprocessing`)。它们可以在同步或异步模式下,在单节点或跨节点上高效使用。

资源:在专用文件夹中查找这些收集器的示例。

注意

选择子收集器:所有分布式收集器都支持各种单机收集器。人们可能会问为什么还要使用 MultiSyncDataCollectorParallelEnv。总的来说,多进程收集器的 IO 开销比需要每步通信的并行环境要低。然而,模型规范也在朝着相反的方向发挥作用,因为使用并行环境将导致策略(和/或转换)执行速度更快,因为这些操作将被向量化。

注意

选择收集器(或并行环境)的设备:通过并行环境和在 CPU 上执行的多进程环境共享内存缓冲区来实现进程间数据共享。根据所用机器的功能,这可能比在 GPU 上共享数据(由 CUDA 驱动程序原生支持)慢得令人难以接受。在实践中,这意味着在构建并行环境或收集器时使用 `device="cpu"` 关键字参数可能比在可用时使用 `device="cuda"` 慢。

注意

鉴于该库有许多可选依赖项(例如 Gym、Gymnasium 和许多其他),在多进程/分布式设置中,警告可能会变得非常烦人。默认情况下,TorchRL 会过滤掉这些子进程中的警告。如果仍希望看到这些警告,可以通过设置 `torchrl.filter_warnings_subprocess=False` 来显示它们。

DistributedDataCollector(create_env_fn, ...)

一个使用 torch.distributed 后端的分布式数据收集器。

RPCDataCollector(create_env_fn, policy, ...)

一个基于 RPC 的分布式数据收集器。

DistributedSyncDataCollector(create_env_fn, ...)

一个使用 torch.distributed 后端的分布式同步数据收集器。

submitit_delayed_launcher(num_jobs[, ...])

submitit 的延迟启动器。

RayCollector(create_env_fn, policy, ...[, ...])

一个使用 Ray 后端的分布式数据收集器。

辅助函数

split_trajectories(rollout_tensordict, *[, ...])

一个用于轨迹分离的实用函数。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源