DataCollectorBase¶
- class torchrl.collectors.DataCollectorBase[source]¶
数据收集器的基类。
- async_shutdown(timeout: float | None = None, close_env: bool = True) None [source]¶
当收集器通过 start 方法异步启动时,关闭收集器。
- 参数
timeout (float, optional): 等待收集器关闭的最长时间。 close_env (bool, optional): 如果为 True,收集器将关闭包含的环境。
默认为 True。
另请参阅
- init_updater(*args, **kwargs)[source]¶
使用自定义参数初始化权重更新器。
此方法将参数传递给权重更新器的 init 方法。如果未设置权重更新器,则此方法无效。
- 参数:
*args – 用于权重更新器初始化的位置参数
**kwargs – 用于权重更新器初始化的关键字参数
- start()[source]¶
启动收集器以进行异步数据收集。
此方法启动后台数据收集,允许数据收集和训练解耦。
收集的数据通常存储在收集器初始化期间传入的经验回放缓冲区中。
注意
调用此方法后,务必使用
async_shutdown()
关闭收集器以释放资源。警告
由于其解耦的性质,异步数据收集可能会显著影响训练性能。在使用此模式之前,请确保了解其对您特定算法的影响。
- 抛出:
NotImplementedError – 如果子类未实现。
- update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) None [source]¶
更新数据收集器的策略权重,支持本地和远程执行上下文。
此方法确保数据收集器使用的策略权重与最新的训练权重同步。它支持本地和远程权重更新,具体取决于数据收集器的配置。本地(下载)更新在远程(上传)更新之前执行,以便可以将权重从服务器传输到子工作器。
- 参数:
policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None) – 要更新的权重。可以是: - TensorDictModuleBase:一个将提取权重的策略模块 - TensorDictBase:一个包含权重的 TensorDict - dict:一个包含权重的常规 dict - None:将尝试从服务器获取权重,使用 _get_server_weights()
worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – 需要更新的工作器的标识符。当收集器关联有多个工作器时,此参数才相关。
- 抛出:
TypeError – 如果提供了 worker_ids 但未配置 weight_updater。
注意
用户应扩展 WeightUpdaterBase 类来定制特定用例的权重更新逻辑。不应覆盖此方法。
另请参阅
LocalWeightsUpdaterBase
和RemoteWeightsUpdaterBase()
。