快捷方式

RPCWeightUpdater

class torchrl.collectors.distributed.RPCWeightUpdater(collector_infos, collector_class, collector_rrefs, policy_weights: TensorDictBase, num_workers: int)[源代码]

一个远程权重更新器,用于使用 RPC 在远程工作者之间同步策略权重。

The RPCWeightUpdater class provides a mechanism for updating the weights of a policy across remote inference workers using RPC. It is designed to work with the RPCDataCollector to ensure that each worker receives the latest policy weights. This class is typically used in distributed data collection scenarios where remote workers are managed via RPC and need to be kept in sync with the central policy weights。

参数:
  • collector_infos – 关于 collector 的信息,用于 RPC 通信。

  • collector_class – 所使用的 collector 的类。

  • collector_rrefs – collector 的远程引用。

  • policy_weights (TensorDictBase) – 需要分发给工作者的策略的当前权重。

  • num_workers (int) – 将接收更新策略权重的远程工作者的数量。

update_weights()

使用 RPC 更新指定或所有远程工作者的权重。

all_worker_ids()[源代码]

返回所有工作者标识符的列表(在此类中未实现)。

_sync_weights_with_worker()[源代码]

将服务器权重与特定工作者同步(未实现)。

_get_server_weights()[源代码]

从服务器检索最新权重(未实现)。

_maybe_map_weights()[源代码]

分发前可选地映射服务器权重(未实现)。

注意

该类假定服务器权重可以直接应用于远程工作者,而无需任何额外的处理。如果您的用例需要更复杂的权重映射或同步逻辑,请考虑扩展 WeightUpdaterBase 并实现自定义实现。

另请参阅

WeightUpdaterBaseRPCDataCollector

all_worker_ids() list[int] | list[torch.device][源代码]

获取所有工作进程 ID 的列表。

默认返回 None。子类应覆盖以返回实际的工作进程 ID。

返回:

工作进程 ID 列表或 None。

返回类型:

list[int] | list[torch.device] | None

property collector: Any | None

接收器的收集器或容器。

如果容器超出范围或未设置,则返回None

property collectors: list[Any] | None

收集器或接收者容器。

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

可选的类方法,用于从策略创建权重更新器实例。

子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。

参数:

policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。

返回:

权重更新器的实例,或者如果策略无法创建实例则为 None。

无法用于创建实例的实例。

返回类型:

WeightUpdaterBase | None

increment_version()

增加策略版本。

init(*args, **kwargs)

使用自定义参数初始化权重更新器。

子类可以覆盖此方法以处理自定义初始化。默认情况下,这是一个无操作。

参数:
  • *args – 初始化位置参数

  • **kwargs – 初始化关键字参数

property post_hooks: list[collections.abc.Callable[[], None]]

注册到权重更新器的后置钩子列表。

push_weights(weights: TensorDictBase | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, **kwargs)[源代码]

更新策略的权重,或在指定/所有远程工作进程上更新。

参数:
  • policy_or_weights – 从中获取权重的来源。可以是: - TensorDictModuleBase:将提取权重的策略模块 - TensorDictBase:包含权重的 TensorDict - dict:一个包含权重的普通字典 - None:将尝试使用 _get_server_weights() 从服务器获取权重。

  • worker_ids – 要更新的工作进程的可选列表。

返回:无。

register_collector(collector)

在更新器中注册一个收集器。

注册后,更新器将不再接受另一个收集器。

参数:

collector (DataCollectorBase) – 要注册的 collector。

register_post_hook(hook: Callable[[], None])

注册一个后置钩子,在权重更新后调用。

参数:

hook (Callable[[], None]) – 要注册的后置钩子。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源