WeightUpdaterBase¶
- class torchrl.collectors.WeightUpdaterBase[source]¶
用于在远程推理工作者上更新远程策略权重的基类。
权重更新器是权重更新方案的核心部分。
在叶子收集器节点中,它负责将权重发送到策略,这可以像更新 state_dict 一样简单,或者如果使用了推理服务器,则可能更复杂。
在服务器收集器节点中,它负责将权重发送到叶子收集器。
在收集器中,更新器在
update_policy_weights_()
中被调用。此类中的主要方法是
_push_weights()
方法,该方法更新工作者/策略中的策略权重。此方法由push_weights()
调用,后者也调用后置钩子:只有 _push_weights 应该由子类实现。要扩展此类,请实现以下抽象方法
- _get_server_weights (可选):定义如何从服务器检索权重,如果它们未直接传递给
更新器。仅当未直接传递权重(句柄)时,才会调用此方法。
- _sync_weights_with_worker:定义如何与特定工作者同步权重。
此方法必须由子类实现。
- _maybe_map_weights:在分发之前可选地转换服务器权重。
默认情况下,此方法返回未更改的权重。
- all_worker_ids:提供所有工作者标识符的列表。
默认返回 None(没有工作者 ID)。
- from_policy (可选类方法):定义如何从策略创建权重更新器的实例。
如果实现,在初始化收集器中的权重更新器时,将在回退到默认构造函数之前调用此方法。
- 变量:
collector – 权重接收器的收集器(或任何容器)。收集器通过
register_collector()
注册。
- 后置钩子
- register_post_hook:注册一个后置钩子,在权重更新后调用。
后置钩子必须是接受零个参数的可调用对象。在权重更新后调用后置钩子。后置钩子将在与权重更新器相同的进程中调用。后置钩子将按照注册的顺序调用。
另请参阅
- all_worker_ids() list[int] | list[torch.device] | None [source]¶
获取所有工作进程 ID 的列表。
默认返回 None。子类应覆盖以返回实际的工作进程 ID。
- 返回:
工作进程 ID 列表或 None。
- 返回类型:
list[int] | list[torch.device] | None
- property collector: torch.collector.DataCollectorBase¶
接收器的收集器或容器。
如果容器超出范围或未设置,则返回None。
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None [source]¶
可选的类方法,用于从策略创建权重更新器实例。
子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。
- 参数:
policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。
- 返回:
- 权重更新器的实例,或者如果策略无法创建实例则为 None。
无法用于创建实例的实例。
- 返回类型:
WeightUpdaterBase | None
- init(*args, **kwargs)[source]¶
使用自定义参数初始化权重更新器。
子类可以覆盖此方法以处理自定义初始化。默认情况下,这是一个无操作。
- 参数:
*args – 初始化位置参数
**kwargs – 初始化关键字参数
- property post_hooks: list[Callable[[], NoneType]]¶
注册到权重更新器的后置钩子列表。
- push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)[source]¶
更新策略的权重,或在指定/所有远程工作进程上更新。
- 参数:
policy_or_weights – 从中获取权重的源。可以是:- TensorDictModuleBase:一个策略模块,其权重将被提取- TensorDictBase:包含权重的 TensorDict- dict:一个常规的包含权重的 dict- None:将尝试使用 _get_server_weights() 从服务器获取权重
worker_ids – 要更新的工作进程的可选列表。
返回:无。
- register_collector(collector: DataCollectorBase)[source]¶
在更新器中注册一个收集器。
注册后,更新器将不再接受另一个收集器。
- 参数:
collector (DataCollectorBase) – 要注册的 collector。