快捷方式

RayWeightUpdater

class torchrl.collectors.RayWeightUpdater(policy_weights: TensorDictBase, remote_collectors: list, max_interval: int = 0)[source]

一个远程权重更新器,用于使用 Ray 在远程工作者之间同步策略权重。

RayWeightUpdater 类提供了一种机制,用于在 Ray 管理的远程推理工作者之间更新策略权重。它利用 Ray 的分布式计算能力,将策略权重高效地分发到远程收集器。此类通常用于分布式数据收集器,其中每个远程工作者都需要策略权重的最新副本。

参数:
  • policy_weights (TensorDictBase) – 需要分发到远程工作者的当前策略权重。

  • remote_collectors (List) – 将接收更新策略权重的远程收集器列表。

  • max_interval (int, optional) – 每个工作者进行权重更新之间的最大批次数。默认为 0,表示每批更新一次权重。

all_worker_ids()[source]

返回所有工作者标识符(远程收集器的索引)的列表。

_get_server_weights()[source]

从服务器检索最新权重并将其存储在 Ray 的对象存储中。

_maybe_map_weights()[source]

在分发前可选地映射服务器权重(在此实现中无操作)。

_sync_weights_with_worker()[source]

使用 Ray 将服务器权重与特定远程工作者同步。

_skip_update()[source]

根据间隔确定是否跳过特定工作者的权重更新。

注意

此类假定服务器权重可以直接应用于远程工作者,无需任何额外处理。如果您的用例需要更复杂的权重映射或同步逻辑,请考虑通过自定义实现来扩展 WeightUpdaterBase

另请参阅

WeightUpdaterBase and RayCollector.

all_worker_ids() list[int] | list[torch.device][source]

获取所有工作进程 ID 的列表。

默认返回 None。子类应覆盖以返回实际的工作进程 ID。

返回:

工作进程 ID 列表或 None。

返回类型:

list[int] | list[torch.device] | None

property collector: torch.collector.DataCollectorBase

接收器的收集器或容器。

如果容器超出范围或未设置,则返回None

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

可选的类方法,用于从策略创建权重更新器实例。

子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。

参数:

policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。

返回:

权重更新器的实例,或者如果策略无法创建实例则为 None。

无法用于创建实例的实例。

返回类型:

WeightUpdaterBase | None

increment_version()

增加策略版本。

init(*args, **kwargs)

使用自定义参数初始化权重更新器。

子类可以覆盖此方法以处理自定义初始化。默认情况下,这是一个无操作。

参数:
  • *args – 初始化位置参数

  • **kwargs – 初始化关键字参数

property post_hooks: list[Callable[[], NoneType]]

注册到权重更新器的后置钩子列表。

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)

更新策略的权重,或在指定/所有远程工作进程上更新。

参数:
  • policy_or_weights – 从中获取权重的源。可以是: - TensorDictModuleBase:一个将提取其权重的策略模块 - TensorDictBase:一个包含权重的 TensorDict - dict:一个包含权重的常规 dict - None:将尝试通过 _get_server_weights() 从服务器获取权重

  • worker_ids – 要更新的工作进程的可选列表。

返回:无。

register_collector(collector: DataCollectorBase)

在更新器中注册一个收集器。

注册后,更新器将不再接受另一个收集器。

参数:

collector (DataCollectorBase) – 要注册的 collector。

register_post_hook(hook: Callable[[], None])

注册一个后置钩子,在权重更新后调用。

参数:

hook (Callable[[], None]) – 要注册的后置钩子。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源