快捷方式

DataCollectorBase

class torchrl.collectors.DataCollectorBase[source]

数据收集器的基类。

async_shutdown(timeout: float | None = None, close_env: bool = True) None[source]

当收集器通过 start 方法异步启动时,关闭收集器。

参数

timeout (float, optional): 等待收集器关闭的最长时间。 close_env (bool, optional): 如果为 True,收集器将关闭包含的环境。

默认为 True

另请参阅

start()

init_updater(*args, **kwargs)[source]

使用自定义参数初始化权重更新器。

此方法将参数传递给权重更新器的 init 方法。如果未设置权重更新器,则此方法无效。

参数:
  • *args – 用于权重更新器初始化的位置参数

  • **kwargs – 用于权重更新器初始化的关键字参数

pause()[source]

上下文管理器,如果收集器正在自由运行,则暂停收集器。

start()[source]

启动收集器以进行异步数据收集。

此方法启动后台数据收集,允许数据收集和训练解耦。

收集的数据通常存储在收集器初始化期间传入的经验回放缓冲区中。

注意

调用此方法后,务必使用 async_shutdown() 关闭收集器以释放资源。

警告

由于其解耦的性质,异步数据收集可能会显著影响训练性能。在使用此模式之前,请确保了解其对您特定算法的影响。

抛出:

NotImplementedError – 如果子类未实现。

update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) None[source]

更新数据收集器的策略权重,支持本地和远程执行上下文。

此方法确保数据收集器使用的策略权重与最新的训练权重同步。它支持本地和远程权重更新,具体取决于数据收集器的配置。本地(下载)更新在远程(上传)更新之前执行,以便可以将权重从服务器传输到子工作器。

参数:
  • policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None) – 要更新的权重。可以是: - TensorDictModuleBase:一个将提取权重的策略模块 - TensorDictBase:一个包含权重的 TensorDict - dict:一个包含权重的常规 dict - None:将尝试从服务器获取权重,使用 _get_server_weights()

  • worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – 需要更新的工作器的标识符。当收集器关联有多个工作器时,此参数才相关。

抛出:

TypeError – 如果提供了 worker_ids 但未配置 weight_updater

注意

用户应扩展 WeightUpdaterBase 类来定制特定用例的权重更新逻辑。不应覆盖此方法。

另请参阅

LocalWeightsUpdaterBaseRemoteWeightsUpdaterBase()

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源