DistributedWeightUpdater¶
- class torchrl.collectors.distributed.DistributedWeightUpdater(store: dict[str, str], policy_weights: TensorDictBase, num_workers: int, sync: bool)[source]¶
用于同步分布式工作程序之间策略权重的远程权重更新器。
DistributedWeightUpdater 类提供了一种跨分布式推理工作程序更新策略权重的机制。它旨在与
DistributedDataCollector
配合使用,以确保每个工作程序都接收到最新的策略权重。此类通常用于分布式数据收集场景,其中需要使多个工作程序与中心策略权重保持同步。- 参数:
store (dict[str, str]) – 用于服务器和分布式工作程序之间通信的类字典存储。
policy_weights (TensorDictBase) – 需要分发给工作程序的当前策略权重。
num_workers (int) – 将接收更新策略的工作程序的数量。
sync (bool) – 如果为
True
,则同步发生(服务器等待工作程序完成更新以重新开始运行)。
- update_weights()¶
更新指定或所有分布式工作程序上的权重。
注意
此类假定服务器权重无需任何额外处理即可直接应用于分布式工作程序。如果您的用例需要更复杂的权重映射或同步逻辑,请考虑使用自定义实现扩展 WeightUpdaterBase。
- 抛出:
RuntimeError – 如果工作程序 rank 小于 1 或从存储返回的状态不是“updated”。
- all_worker_ids() list[int] | list[torch.device] [source]¶
获取所有工作进程 ID 的列表。
默认返回 None。子类应覆盖以返回实际的工作进程 ID。
- 返回:
工作进程 ID 列表或 None。
- 返回类型:
list[int] | list[torch.device] | None
- property collector: torch.collector.DataCollectorBase¶
接收器的收集器或容器。
如果容器超出范围或未设置,则返回None。
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None ¶
可选的类方法,用于从策略创建权重更新器实例。
子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。
- 参数:
policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。
- 返回:
- 权重更新器的实例,或者如果策略无法创建实例则为 None。
无法用于创建实例的实例。
- 返回类型:
WeightUpdaterBase | None
- increment_version()¶
增加策略版本。
- init(*args, **kwargs)¶
使用自定义参数初始化权重更新器。
子类可以覆盖此方法以处理自定义初始化。默认情况下,这是一个无操作。
- 参数:
*args – 初始化位置参数
**kwargs – 初始化关键字参数
- property post_hooks: list[Callable[[], NoneType]]¶
注册到权重更新器的后置钩子列表。
- push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)¶
更新策略的权重,或在指定/所有远程工作进程上更新。
- 参数:
policy_or_weights – 获取权重的来源。可以是: - TensorDictModuleBase:将提取其权重的策略模块 - TensorDictBase:包含权重的 TensorDict - dict:常规的包含权重的 dict - None:将尝试使用 _get_server_weights() 从服务器获取权重
worker_ids – 要更新的工作进程的可选列表。
返回:无。
- register_collector(collector: DataCollectorBase)¶
在更新器中注册一个收集器。
注册后,更新器将不再接受另一个收集器。
- 参数:
collector (DataCollectorBase) – 要注册的 collector。
- register_post_hook(hook: Callable[[], None])¶
注册一个后置钩子,在权重更新后调用。
- 参数:
hook (Callable[[], None]) – 要注册的后置钩子。