RPCWeightUpdater¶
- class torchrl.collectors.distributed.RPCWeightUpdater(collector_infos, collector_class, collector_rrefs, policy_weights: TensorDictBase, num_workers: int)[源代码]¶
一个远程权重更新器,用于使用 RPC 在远程工作者之间同步策略权重。
The RPCWeightUpdater class provides a mechanism for updating the weights of a policy across remote inference workers using RPC. It is designed to work with the
RPCDataCollector
to ensure that each worker receives the latest policy weights. This class is typically used in distributed data collection scenarios where remote workers are managed via RPC and need to be kept in sync with the central policy weights。- 参数:
collector_infos – 关于 collector 的信息,用于 RPC 通信。
collector_class – 所使用的 collector 的类。
collector_rrefs – collector 的远程引用。
policy_weights (TensorDictBase) – 需要分发给工作者的策略的当前权重。
num_workers (int) – 将接收更新策略权重的远程工作者的数量。
- update_weights()¶
使用 RPC 更新指定或所有远程工作者的权重。
注意
该类假定服务器权重可以直接应用于远程工作者,而无需任何额外的处理。如果您的用例需要更复杂的权重映射或同步逻辑,请考虑扩展 WeightUpdaterBase 并实现自定义实现。
另请参阅
- all_worker_ids() list[int] | list[torch.device] [源代码]¶
获取所有工作进程 ID 的列表。
默认返回 None。子类应覆盖以返回实际的工作进程 ID。
- 返回:
工作进程 ID 列表或 None。
- 返回类型:
list[int] | list[torch.device] | None
- property collector: Any | None¶
接收器的收集器或容器。
如果容器超出范围或未设置,则返回None。
- property collectors: list[Any] | None¶
收集器或接收者容器。
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None ¶
可选的类方法,用于从策略创建权重更新器实例。
子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。
- 参数:
policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。
- 返回:
- 权重更新器的实例,或者如果策略无法创建实例则为 None。
无法用于创建实例的实例。
- 返回类型:
WeightUpdaterBase | None
- increment_version()¶
增加策略版本。
- init(*args, **kwargs)¶
使用自定义参数初始化权重更新器。
子类可以覆盖此方法以处理自定义初始化。默认情况下,这是一个无操作。
- 参数:
*args – 初始化位置参数
**kwargs – 初始化关键字参数
- property post_hooks: list[collections.abc.Callable[[], None]]¶
注册到权重更新器的后置钩子列表。
- push_weights(weights: TensorDictBase | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, **kwargs)[源代码]¶
更新策略的权重,或在指定/所有远程工作进程上更新。
- 参数:
policy_or_weights – 从中获取权重的来源。可以是: - TensorDictModuleBase:将提取权重的策略模块 - TensorDictBase:包含权重的 TensorDict - dict:一个包含权重的普通字典 - None:将尝试使用 _get_server_weights() 从服务器获取权重。
worker_ids – 要更新的工作进程的可选列表。
返回:无。
- register_collector(collector)¶
在更新器中注册一个收集器。
注册后,更新器将不再接受另一个收集器。
- 参数:
collector (DataCollectorBase) – 要注册的 collector。
- register_post_hook(hook: Callable[[], None])¶
注册一个后置钩子,在权重更新后调用。
- 参数:
hook (Callable[[], None]) – 要注册的后置钩子。