
VanillaWeightUpdater

class torchrl.collectors.VanillaWeightUpdater(*, weight_getter: Callable[[], TensorDictBase] | None = None, policy_weights: TensorDictBase)

A simple implementation of WeightUpdaterBase for updating local policy weights.

The VanillaWeightUpdater class provides a basic mechanism for updating the weights of a local policy by fetching them directly from a specified source. It suits cases where the weight update logic is straightforward and needs no mapping or transformation.

This class is used by default in SyncDataCollector when no custom weight updater is provided.

Keyword Arguments:
  • weight_getter (Callable[[], TensorDictBase], optional) – a callable that returns the weights from the server. If not provided, the weights must be passed to update_weights() directly.

  • policy_weights (TensorDictBase) – a TensorDictBase containing the policy weights to be updated in-place.
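As a rough illustration of the "fetch, then copy in place" behavior described above, the sketch below models weights as a plain dict of lists instead of a TensorDictBase. SketchVanillaUpdater and its update_weights() method are hypothetical stand-ins written for this example, not TorchRL's implementation; the ValueError raised when no weights are available is likewise an assumption.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical stand-in for a TensorDictBase: parameter name -> list of floats.
Weights = Dict[str, List[float]]


class SketchVanillaUpdater:
    """Hypothetical sketch of the fetch-then-copy-in-place pattern (not TorchRL code)."""

    def __init__(
        self,
        *,
        policy_weights: Weights,
        weight_getter: Optional[Callable[[], Weights]] = None,
    ) -> None:
        self.policy_weights = policy_weights  # local weights, mutated in place
        self.weight_getter = weight_getter  # optional source of fresh weights

    def update_weights(self, weights: Optional[Weights] = None) -> None:
        # Without a getter, the caller must pass the weights explicitly.
        if weights is None:
            if self.weight_getter is None:
                raise ValueError("pass weights or provide a weight_getter")
            weights = self.weight_getter()
        # Overwrite the contents of the existing lists rather than rebinding them,
        # so every holder of self.policy_weights observes the new values.
        for name, value in weights.items():
            self.policy_weights[name][:] = value


# Usage: the "server" weights are copied into the local containers.
local = {"w": [0.0, 0.0]}
updater = SketchVanillaUpdater(
    policy_weights=local, weight_getter=lambda: {"w": [1.0, 2.0]}
)
w_before = local["w"]
updater.update_weights()
print(local)  # {'w': [1.0, 2.0]}
```

Writing into the existing containers, rather than replacing them, is what "updated in-place" buys: any object that already holds a reference to the local weights sees the new values without being rewired.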

all_worker_ids() list[int] | list[torch.device] | None

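Get the list of all worker IDs.

Returns None by default. Subclasses should override this method to return the actual worker IDs.

Returns:

A list of worker IDs, or None.

Return type: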

list[int] | list[torch.device] | None
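The override pattern described above might look like the following minimal sketch. Both classes are hypothetical and only mirror the documented contract: the base returns None, and a subclass that knows its workers returns their IDs.

```python
from typing import List, Optional


class SketchUpdaterBase:
    """Hypothetical base mirroring the documented default."""

    def all_worker_ids(self) -> Optional[List[int]]:
        return None  # no workers known by default


class SketchMultiWorkerUpdater(SketchUpdaterBase):
    """Hypothetical subclass that manages a fixed number of workers."""

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers

    def all_worker_ids(self) -> Optional[List[int]]:
        # Integer IDs 0..num_workers-1; a real subclass could return devices instead.
        return list(range(self.num_workers))


print(SketchUpdaterBase().all_worker_ids())  # None
print(SketchMultiWorkerUpdater(3).all_worker_ids())  # [0, 1, 2]
```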

property collector: torchrl.collectors.DataCollectorBase

The collector or container of the receiver.

Returns None if the container is out of scope or has not been set.
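"Out of scope" suggests that the updater does not keep the collector alive. A minimal sketch of that behavior, assuming the reference is stored as a weakref.ref (an assumption; the storage mechanism is not documented here), could look like this. SketchCollector, SketchUpdater and set_collector() are hypothetical names.

```python
import weakref
from typing import Optional


class SketchCollector:
    """Hypothetical placeholder for a data collector."""


class SketchUpdater:
    def __init__(self) -> None:
        self._collector_ref: Optional[weakref.ref] = None

    def set_collector(self, collector: SketchCollector) -> None:
        # A weak reference does not prevent the collector from being freed.
        self._collector_ref = weakref.ref(collector)

    @property
    def collector(self) -> Optional[SketchCollector]:
        if self._collector_ref is None:
            return None  # never set
        return self._collector_ref()  # None once the collector has been freed


updater = SketchUpdater()
print(updater.collector)  # None
collector = SketchCollector()
updater.set_collector(collector)
print(updater.collector is collector)  # True
del collector  # drop the last strong reference
print(updater.collector)  # None under CPython reference counting
```

Holding only a weak reference avoids a reference cycle between the collector and its updater, so dropping the collector is enough to free it.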

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

Creates a VanillaWeightUpdater instance from a policy.

This method creates a weight updater that will update the policy’s weights directly using its state dict.

Args:

policy (TensorDictModuleBase) – the policy from which to create the weight updater.

Returns:

An instance of the weight updater configured to update the policy’s weights.

Return type:

VanillaWeightUpdater
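A minimal sketch of the from_policy() idea follows. It assumes that the policy exposes a state_dict() whose values are the policy's live parameter containers, so that writing into them updates the policy itself. TinyPolicy and SketchUpdater are hypothetical classes written for this example.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


class TinyPolicy:
    """Hypothetical policy whose parameters are lists of floats."""

    def __init__(self) -> None:
        self.params: Weights = {"w": [0.0, 0.0]}

    def state_dict(self) -> Weights:
        # Returns the live containers (not copies) so in-place writes reach the policy.
        return self.params


class SketchUpdater:
    def __init__(self, *, policy_weights: Weights) -> None:
        self.policy_weights = policy_weights

    @classmethod
    def from_policy(cls, policy: TinyPolicy) -> "SketchUpdater":
        # Bind the updater to the policy's own state so updates act on it directly.
        return cls(policy_weights=policy.state_dict())

    def update_weights(self, weights: Weights) -> None:
        for name, value in weights.items():
            self.policy_weights[name][:] = value


policy = TinyPolicy()
updater = SketchUpdater.from_policy(policy)
updater.update_weights({"w": [0.5, -0.5]})
print(policy.params)  # {'w': [0.5, -0.5]}
```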

increment_version()

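Increment the policy version.

init(*args, **kwargs)

Initialize the weight updater with custom arguments.

Subclasses can override this method to handle custom initialization. By default, it is a no-op.

Args:
  • *args – positional arguments for initialization

  • **kwargs – keyword arguments for initialization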

property post_hooks: list[Callable[[], NoneType]]

The list of post-hooks registered with the weight updater.

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)

Update the weights of the policy, or update them on the specified workers or on all remote workers.

Args:
  • policy_or_weights – the source to get the weights from. Can be:
      - TensorDictModuleBase: a policy module whose weights will be extracted
      - TensorDictBase: a TensorDict containing the weights
      - dict: a regular dict containing the weights
      - None: the weights are fetched from the server using _get_server_weights()

  • worker_ids – an optional worker ID or device, or a list of them, identifying the workers to update.

Returns: None.
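The argument handling described above can be sketched as a small dispatch. This is a hypothetical illustration, not TorchRL's code. Weights are plain dicts, so the TensorDictBase and dict cases collapse into one branch. TinyPolicy stands in for a TensorDictModuleBase, and the server_weights callable stands in for _get_server_weights(). Treating worker_ids=None as "all workers" is an assumption based on the method's summary line.

```python
from typing import Callable, Dict, List, Union

Weights = Dict[str, List[float]]


class TinyPolicy:
    """Hypothetical stand-in for a TensorDictModuleBase policy."""

    def __init__(self, weights: Weights) -> None:
        self._weights = weights

    def state_dict(self) -> Weights:
        return self._weights


class SketchPushUpdater:
    """Hypothetical updater that keeps one weight copy per worker."""

    def __init__(self, server_weights: Callable[[], Weights], num_workers: int) -> None:
        self._get_server_weights = server_weights
        self.workers: Dict[int, Weights] = {i: {} for i in range(num_workers)}

    def all_worker_ids(self) -> List[int]:
        return list(self.workers)

    def push_weights(
        self,
        policy_or_weights: Union[TinyPolicy, Weights, None] = None,
        worker_ids: Union[int, List[int], None] = None,
    ) -> None:
        # 1. Resolve the weight source.
        if policy_or_weights is None:
            weights = self._get_server_weights()  # fall back to the server
        elif isinstance(policy_or_weights, TinyPolicy):
            weights = policy_or_weights.state_dict()  # extract from a policy
        else:
            weights = policy_or_weights  # already a mapping of weights
        # 2. Resolve the target workers (assumption: None means all of them).
        if worker_ids is None:
            targets = self.all_worker_ids()
        elif isinstance(worker_ids, int):
            targets = [worker_ids]
        else:
            targets = list(worker_ids)
        # 3. Give each target its own copy of the weights.
        for wid in targets:
            self.workers[wid] = {name: list(v) for name, v in weights.items()}


updater = SketchPushUpdater(lambda: {"w": [1.0]}, num_workers=3)
updater.push_weights()  # server weights to all workers
updater.push_weights({"w": [2.0]}, worker_ids=1)  # plain dict to one worker
updater.push_weights(TinyPolicy({"w": [5.0]}), worker_ids=[0, 2])  # policy to two
print(updater.workers)  # {0: {'w': [5.0]}, 1: {'w': [2.0]}, 2: {'w': [5.0]}}
```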

register_collector(collector: DataCollectorBase)

Register a collector with the updater.

Once a collector is registered, the updater will not accept another one.

Args:

collector (DataCollectorBase) – the collector to register.
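Here is a minimal sketch of the "register once" rule, using a hypothetical SketchUpdater. Raising RuntimeError on a second registration is an assumption made for illustration. This page does not say how the real updater rejects a second collector.

```python
from typing import Optional


class SketchUpdater:
    def __init__(self) -> None:
        self._collector: Optional[object] = None

    def register_collector(self, collector: object) -> None:
        # Refuse to replace an already-registered collector.
        if self._collector is not None:
            raise RuntimeError("a collector is already registered")
        self._collector = collector


updater = SketchUpdater()
updater.register_collector("collector-A")
try:
    updater.register_collector("collector-B")
except RuntimeError as err:
    print(err)  # a collector is already registered
```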

register_post_hook(hook: Callable[[], None])

Register a post-hook to be called after the weights are updated.

Args:

hook (Callable[[], None]) – the post-hook to register.
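The sketch below shows how registered post-hooks, as exposed by the post_hooks property, might run after each update. SketchHookedUpdater is hypothetical. Running the hooks in registration order is a choice made in this sketch, not a documented guarantee.

```python
from typing import Callable, Dict, List


class SketchHookedUpdater:
    """Hypothetical updater illustrating post-hooks (not TorchRL code)."""

    def __init__(self) -> None:
        self._post_hooks: List[Callable[[], None]] = []
        self.weights: Dict[str, float] = {}

    @property
    def post_hooks(self) -> List[Callable[[], None]]:
        return self._post_hooks

    def register_post_hook(self, hook: Callable[[], None]) -> None:
        self._post_hooks.append(hook)

    def push_weights(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)
        # Hooks take no arguments and run only after the weights are in place.
        for hook in self._post_hooks:
            hook()


events: List[str] = []
updater = SketchHookedUpdater()
updater.register_post_hook(lambda: events.append(f"log w={updater.weights['w']}"))
updater.register_post_hook(lambda: events.append("sync"))
updater.push_weights({"w": 0.25})
print(events)  # ['log w=0.25', 'sync']
```

Because the hooks take no arguments, any state they need (here, the updated weights) has to be reached through a closure or through the updater itself.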
