VanillaWeightUpdater
- class torchrl.collectors.VanillaWeightUpdater(*, weight_getter: Callable[[], TensorDictBase] | None = None, policy_weights: TensorDictBase)
A simple implementation of WeightUpdaterBase for updating local policy weights. The VanillaWeightUpdater class provides a basic mechanism for updating the weights of a local policy by fetching them directly from a specified source. It is typically used when the weight update logic is straightforward and requires no mapping or transformation of the weights.
This class is used by default in the SyncDataCollector when no custom weight updater is provided.
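The key point of the mechanism above is that the fetched weights are copied *into* the existing local weights rather than replacing them, so anything that already references those weights (such as the policy) sees the new values. The sketch below is a toy model of that idea, not TorchRL's implementation: plain dicts of Python lists stand in for `TensorDictBase`, and `copy_weights_in_place` is a hypothetical helper.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]  # toy stand-in for a TensorDictBase


def copy_weights_in_place(source: Weights, target: Weights) -> None:
    """Overwrite target's values element-wise without replacing its containers."""
    for name, new_values in source.items():
        if len(target[name]) != len(new_values):
            raise ValueError(f"shape mismatch for {name!r}")
        # Slice assignment mutates the existing list object, so every
        # holder of a reference to target[name] observes the update.
        target[name][:] = new_values


# A toy "policy" that holds a reference to the same list as policy_weights.
policy_weights = {"weight": [0.0, 0.0], "bias": [0.0]}
policy_view = policy_weights["weight"]  # e.g. a module parameter

server_weights = {"weight": [0.5, -1.0], "bias": [2.0]}
copy_weights_in_place(server_weights, policy_weights)
print(policy_view)  # [0.5, -1.0]
```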
See also
- Keyword Arguments:
weight_getter (Callable[[], TensorDictBase], optional) – a callable that returns the weights from the server. If not provided, the weights must be passed directly to update_weights().
policy_weights (TensorDictBase) – a TensorDictBase containing the policy weights to be updated in place.
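The relationship between `weight_getter` and explicitly passed weights can be sketched as a small resolution rule: use the weights the caller passes, and fall back to the getter only when none are given. This is a hypothetical helper (`resolve_weights`) using plain dicts in place of `TensorDictBase`; raising `ValueError` when neither source is available is an assumption, not documented TorchRL behavior.

```python
from typing import Callable, Dict, List, Optional

Weights = Dict[str, List[float]]  # toy stand-in for a TensorDictBase


def resolve_weights(weights: Optional[Weights],
                    weight_getter: Optional[Callable[[], Weights]]) -> Weights:
    """Prefer explicitly passed weights; otherwise fall back to the getter."""
    if weights is not None:
        return weights
    if weight_getter is None:
        # Without a getter, the caller must supply the weights directly.
        raise ValueError("weights must be passed when no weight_getter is set")
    return weight_getter()


server = {"w": [1.0]}
print(resolve_weights(None, lambda: server))  # {'w': [1.0]}
print(resolve_weights({"w": [9.0]}, None))    # {'w': [9.0]}
```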
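- all_worker_ids() → list[int] | list[torch.device] | None
Gets the list of all worker IDs.
Returns None by default. Subclasses should override this method to return the actual worker IDs.
- Returns:
A list of worker IDs, or None.
- Return type:
list[int] | list[torch.device] | None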
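The override pattern described above might look like the following. Both classes are hypothetical toys (not TorchRL classes): the base mirrors the documented default of returning None, and the subclass returns integer IDs for a fixed number of workers.

```python
from typing import List, Optional


class ToyUpdaterBase:
    """Hypothetical base mirroring the documented default."""

    def all_worker_ids(self) -> Optional[List[int]]:
        return None  # default: no worker IDs known


class ToyMultiWorkerUpdater(ToyUpdaterBase):
    """Hypothetical subclass that knows about several worker processes."""

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers

    def all_worker_ids(self) -> Optional[List[int]]:
        return list(range(self.num_workers))


print(ToyUpdaterBase().all_worker_ids())          # None
print(ToyMultiWorkerUpdater(3).all_worker_ids())  # [0, 1, 2]
```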
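- property collector: torchrl.collectors.DataCollectorBase
The collector (or container) that this weight updater is attached to.
Returns None if the container has gone out of scope or has not been set.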
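The phrase "out of scope" suggests the updater holds only a weak reference to its collector. The sketch below assumes that design and uses Python's `weakref` module; `ToyCollector` and `ToyUpdater` are hypothetical names, and this is not TorchRL's code.

```python
import weakref
from typing import Optional


class ToyCollector:
    """Hypothetical stand-in for a data collector."""


class ToyUpdater:
    def __init__(self) -> None:
        self._collector_ref: Optional[weakref.ref] = None

    def register_collector(self, collector: ToyCollector) -> None:
        # Keep only a weak reference so the updater does not keep the
        # collector alive on its own.
        self._collector_ref = weakref.ref(collector)

    @property
    def collector(self) -> Optional[ToyCollector]:
        if self._collector_ref is None:
            return None  # never set
        return self._collector_ref()  # None once the collector is garbage-collected


updater = ToyUpdater()
print(updater.collector)  # None
c = ToyCollector()
updater.register_collector(c)
print(updater.collector is c)  # True
del c
print(updater.collector)  # None (CPython frees the object immediately)
```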
- classmethod from_policy(policy: TensorDictModuleBase) → WeightUpdaterBase | None
Creates a VanillaWeightUpdater instance from a policy.
This method creates a weight updater that will update the policy’s weights directly using its state dict.
- Parameters:
policy (TensorDictModuleBase) – the policy from which to create the weight updater.
- Returns:
An instance of the weight updater configured to update the policy’s weights.
- Return type:
WeightUpdaterBase | None
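Building an updater "from a policy" amounts to binding the updater's target weights to the policy's own state, so that in-place updates land directly in the policy. In PyTorch, `Module.state_dict()` returns tensors that share storage with the module's parameters, which is what makes this work there. The toy below models that with plain lists; `ToyPolicy`, `ToyVanillaUpdater`, and its `push_weights` method are hypothetical and only illustrate the idea.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]  # toy stand-in for a TensorDictBase


class ToyPolicy:
    """Hypothetical policy exposing a state_dict() of parameter lists."""

    def __init__(self) -> None:
        self.params: Weights = {"w": [0.0, 0.0]}

    def state_dict(self) -> Weights:
        return self.params  # returns the live containers, not copies


class ToyVanillaUpdater:
    def __init__(self, *, policy_weights: Weights) -> None:
        self.policy_weights = policy_weights

    @classmethod
    def from_policy(cls, policy: ToyPolicy) -> "ToyVanillaUpdater":
        # Bind the updater to the policy's own state.
        return cls(policy_weights=policy.state_dict())

    def push_weights(self, weights: Weights) -> None:
        for name, values in weights.items():
            self.policy_weights[name][:] = values  # in-place copy


policy = ToyPolicy()
updater = ToyVanillaUpdater.from_policy(policy)
updater.push_weights({"w": [1.5, 2.5]})
print(policy.params["w"])  # [1.5, 2.5]
```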
- increment_version()
Increments the policy version.
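- init(*args, **kwargs)
Initializes the weight updater with custom arguments.
Subclasses can override this method to handle custom initialization. By default, it is a no-op.
- Parameters:
*args – positional arguments for initialization
**kwargs – keyword arguments for initialization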
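A subclass could use the `init` hook as sketched below. The base class reproduces the documented no-op default, and the subclass simply records its keyword arguments. Both classes and the `settings` attribute are hypothetical illustrations, and the `device` keyword is only an example value.

```python
from typing import Any, Dict


class ToyUpdaterBase:
    def init(self, *args: Any, **kwargs: Any) -> None:
        """Default: a no-op, as documented."""


class ToyConfiguredUpdater(ToyUpdaterBase):
    """Hypothetical subclass that records custom settings in init()."""

    def __init__(self) -> None:
        self.settings: Dict[str, Any] = {}

    def init(self, *args: Any, **kwargs: Any) -> None:
        super().init(*args, **kwargs)
        self.settings.update(kwargs)  # store whatever options were passed


base = ToyUpdaterBase()
base.init(1, 2, flag=True)  # does nothing
u = ToyConfiguredUpdater()
u.init(device="cpu")
print(u.settings)  # {'device': 'cpu'}
```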
- property post_hooks: list[Callable[[], NoneType]]
The list of post-hooks registered with the weight updater.
- push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)
Updates the weights of the policy, either locally or on the specified (or all) remote workers.
- Parameters:
policy_or_weights – the source to get weights from. Can be:
  - TensorDictModuleBase: a policy module whose weights will be extracted
  - TensorDictBase: a TensorDict containing weights
  - dict: a regular dict containing weights
  - None: will try to get weights from the server using _get_server_weights()
worker_ids – an optional worker ID, or list of worker IDs, to update.
Returns: None.
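The input handling described for `push_weights` can be sketched as a normalize-then-dispatch step: turn any accepted source into one weights mapping, then send it to the selected workers (all of them when `worker_ids` is None). This is a toy model, not TorchRL's implementation: `ToyPolicy`, `extract_weights`, and the `workers` dict are hypothetical, plain dicts stand in for both `TensorDictBase` and `dict`, and worker IDs are ints only (the real signature also accepts devices).

```python
from typing import Callable, Dict, List, Optional, Union

Weights = Dict[str, List[float]]  # toy stand-in for TensorDictBase or dict


class ToyPolicy:
    """Hypothetical policy module exposing its weights via state_dict()."""

    def __init__(self, weights: Weights) -> None:
        self._weights = weights

    def state_dict(self) -> Weights:
        return self._weights


def extract_weights(source: Union[ToyPolicy, Weights, None],
                    get_server_weights: Callable[[], Weights]) -> Weights:
    """Normalize the documented input kinds into one weights mapping."""
    if source is None:
        return get_server_weights()  # fall back to the server
    if isinstance(source, ToyPolicy):
        return source.state_dict()  # extract weights from a module
    if isinstance(source, dict):
        return source  # already a mapping of weights
    raise TypeError(f"unsupported source type: {type(source).__name__}")


def push_weights(source: Union[ToyPolicy, Weights, None],
                 worker_ids: Union[int, List[int], None],
                 workers: Dict[int, Weights],
                 get_server_weights: Callable[[], Weights]) -> None:
    """Send weights to the given worker(s), or to every worker when None."""
    weights = extract_weights(source, get_server_weights)
    if worker_ids is None:
        targets = list(workers)
    elif isinstance(worker_ids, int):
        targets = [worker_ids]
    else:
        targets = worker_ids
    for wid in targets:
        # Each toy worker keeps its own copy of the weights.
        workers[wid] = {k: list(v) for k, v in weights.items()}


workers: Dict[int, Weights] = {0: {}, 1: {}, 2: {}}
push_weights(ToyPolicy({"w": [1.0]}), [0, 2], workers, lambda: {"w": [0.0]})
print(workers)  # {0: {'w': [1.0]}, 1: {}, 2: {'w': [1.0]}}
push_weights(None, None, workers, lambda: {"w": [7.0]})
print(workers[1])  # {'w': [7.0]}
```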
- register_collector(collector: DataCollectorBase)
Registers a collector with the updater.
Once a collector is registered, the updater will not accept another one.
- Parameters:
collector (DataCollectorBase) – the collector to register.
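The "register once" rule can be enforced with a simple guard, as sketched below. `ToyCollector` and `ToyUpdater` are hypothetical, and raising `RuntimeError` on a second registration is an assumption; the documentation above only states that another collector is not accepted.

```python
from typing import Optional


class ToyCollector:
    """Hypothetical stand-in for a DataCollectorBase."""


class ToyUpdater:
    def __init__(self) -> None:
        self._collector: Optional[ToyCollector] = None

    def register_collector(self, collector: ToyCollector) -> None:
        # Refuse a second registration, mirroring the documented contract.
        if self._collector is not None:
            raise RuntimeError("a collector is already registered")
        self._collector = collector


updater = ToyUpdater()
first = ToyCollector()
updater.register_collector(first)
try:
    updater.register_collector(ToyCollector())
except RuntimeError as err:
    print(err)  # a collector is already registered
```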
- register_post_hook(hook: Callable[[], None])
Registers a post-hook to be called after the weights are updated.
- Parameters:
hook (Callable[[], None]) – the post-hook to register.
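Post-hooks are zero-argument callables that run once an update has finished, which makes them a natural place for logging or bookkeeping. The toy below shows `post_hooks`, `register_post_hook`, and an update method working together; the class is hypothetical, and running hooks in registration order is an assumption rather than documented TorchRL behavior.

```python
from typing import Callable, Dict, List

Weights = Dict[str, List[float]]  # toy stand-in for a TensorDictBase


class ToyUpdater:
    def __init__(self, policy_weights: Weights) -> None:
        self.policy_weights = policy_weights
        self._post_hooks: List[Callable[[], None]] = []

    @property
    def post_hooks(self) -> List[Callable[[], None]]:
        return self._post_hooks

    def register_post_hook(self, hook: Callable[[], None]) -> None:
        self._post_hooks.append(hook)

    def push_weights(self, weights: Weights) -> None:
        for name, values in weights.items():
            self.policy_weights[name][:] = values
        # Hooks take no arguments and run only after the update is complete.
        for hook in self._post_hooks:
            hook()


log: List[str] = []
updater = ToyUpdater({"w": [0.0]})
updater.register_post_hook(lambda: log.append(f"updated to {updater.policy_weights['w']}"))
updater.push_weights({"w": [4.0]})
print(log)  # ['updated to [4.0]']
```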