VanillaWeightUpdater¶
- class torchrl.collectors.VanillaWeightUpdater(*, weight_getter: Callable[[], TensorDictBase] | None = None, policy_weights: TensorDictBase)[source]¶
一个简单的
WeightUpdaterBase
实现,用于更新本地策略的权重。“VanillaWeightSender”类提供了一种通过直接从指定源获取权重来更新本地策略权重的基本机制。它通常用于权重更新逻辑简单且不需要任何复杂映射或转换的场景。
当未提供自定义权重发送器时,此类的 SyncDataCollector 默认使用它。
另请参阅
- 关键字参数:
weight_getter (Callable[[], TensorDictBase], optional) – 一个返回服务器权重的可调用对象。如果未提供,则必须将权重直接传递给
update_weights()
。policy_weights (TensorDictBase) – 一个 TensorDictBase,其中包含要就地更新的策略权重。
- all_worker_ids() list[int] | list[torch.device] | None ¶
获取所有工作进程 ID 的列表。
默认返回 None。子类应覆盖以返回实际的工作进程 ID。
- 返回:
工作进程 ID 列表或 None。
- 返回类型:
list[int] | list[torch.device] | None
- property collector: Any | None¶
接收器的收集器或容器。
如果容器超出范围或未设置,则返回None。
- property collectors: list[Any] | None¶
收集器或接收者容器。
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None [source]¶
从策略创建 VanillaWeightUpdater 实例。
此方法创建一个权重更新器,该更新器将直接使用策略的 state dict 来更新其权重。
- 参数:
policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。
- 返回:
- 已配置为更新的权重更新器实例
策略的权重。
- 返回类型:
- increment_version()¶
增加策略版本。
- init(*args, **kwargs)¶
使用自定义参数初始化权重更新器。
子类可以覆盖此方法以处理自定义初始化。默认情况下,这是一个无操作。
- 参数:
*args – 初始化位置参数
**kwargs – 初始化关键字参数
- property post_hooks: list[collections.abc.Callable[[], None]]¶
注册到权重更新器的后置钩子列表。
- push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)¶
更新策略的权重,或在指定/所有远程工作进程上更新。
- 参数:
policy_or_weights – 从中获取权重的来源。可以是: - TensorDictModuleBase:将提取权重的策略模块 - TensorDictBase:包含权重的 TensorDict - dict:一个包含权重的普通字典 - None:将尝试使用 _get_server_weights() 从服务器获取权重。
worker_ids – 要更新的工作进程的可选列表。
返回:无。
- register_collector(collector)¶
在更新器中注册一个收集器。
注册后,更新器将不再接受另一个收集器。
- 参数:
collector (DataCollectorBase) – 要注册的 collector。
- register_post_hook(hook: Callable[[], None])¶
注册一个后置钩子,在权重更新后调用。
- 参数:
hook (Callable[[], None]) – 要注册的后置钩子。