快捷方式

WeightUpdaterBase

class torchrl.collectors.WeightUpdaterBase[来源]

用于在推理工作器上更新远程策略权重的基类。

权重更新器是权重更新方案的核心部分

  • 在叶收集器节点中,它负责将权重发送到策略,这可以很简单地更新 state-dict,也可以更复杂,如果使用了推理服务器。

  • 在服务器收集器节点中,它负责将权重发送到叶收集器。

在收集器中,更新器在 update_policy_weights_() 中被调用。

此类的主方法是 _push_weights() 方法,它更新工作器/策略中的策略权重。此方法由 push_weights() 调用,该方法也调用后置钩子:只有 _push_weights 应由子类实现。

要扩展此类,请实现以下抽象方法

  • _get_server_weights (可选):定义如何从服务器检索权重,如果它们未传递给

    更新器。仅当权重(句柄)未直接传递时,才调用此方法。

  • _sync_weights_with_worker:定义如何与特定工作器同步权重。

    此方法必须由子类实现。

  • _maybe_map_weights:可选地在分发之前转换服务器权重。

    默认情况下,此方法返回的权重不变。

  • all_worker_ids:提供所有工作器标识符的列表。

    默认返回 None(无工作器 ID)。

  • from_policy(可选类方法):定义如何从策略创建权重更新器的实例。

    如果实现,在初始化收集器中的权重更新器时,将先调用此方法,然后再回退到默认构造函数。

变量:

collector – 权重接收器的收集器(或任何容器)。收集器通过 register_collector() 注册。

push_weights()[来源]

在指定的或所有远程工作器上更新权重。__call__ 方法是 push_weights 的代理。

register_collector()[来源]

通过弱引用在接收器中注册收集器(或任何容器)。当更新器注册时,收集器将自动调用此方法。

from_policy()[来源]

可选的类方法,用于从策略创建实例。

后置钩子
  • register_post_hook:注册一个后置钩子,在权重更新后被调用。

    后置钩子必须是一个不接受参数的可调用对象。后置钩子将在权重更新后被调用。后置钩子将在与权重更新器相同的进程中调用。后置钩子将按照注册后置钩子的顺序调用。

另请参阅

update_policy_weights_().

all_worker_ids() list[int] | list[torch.device] | None[来源]

获取所有工作进程 ID 的列表。

默认返回 None。子类应覆盖以返回实际的工作进程 ID。

返回:

工作进程 ID 列表或 None。

返回类型:

list[int] | list[torch.device] | None

property collector: Any | None

接收器的收集器或容器。

如果容器超出范围或未设置,则返回None

property collectors: list[Any] | None

收集器或接收者容器。

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None[来源]

可选的类方法,用于从策略创建权重更新器实例。

子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。

参数:

policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。

返回:

权重更新器的实例,或者如果策略无法创建实例则为 None。

无法用于创建实例的实例。

返回类型:

WeightUpdaterBase | None

increment_version()[来源]

增加策略版本。

init(*args, **kwargs)[来源]

使用自定义参数初始化权重更新器。

子类可以覆盖此方法以处理自定义初始化。默认情况下,这是一个无操作。

参数:
  • *args – 初始化位置参数

  • **kwargs – 初始化关键字参数

property post_hooks: list[collections.abc.Callable[[], None]]

注册到权重更新器的后置钩子列表。

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)[来源]

更新策略的权重,或在指定/所有远程工作进程上更新。

参数:
  • policy_or_weights – 从中获取权重的来源。可以是: - TensorDictModuleBase:将提取权重的策略模块 - TensorDictBase:包含权重的 TensorDict - dict:一个包含权重的普通字典 - None:将尝试使用 _get_server_weights() 从服务器获取权重。

  • worker_ids – 要更新的工作进程的可选列表。

返回:无。

register_collector(collector)[来源]

在更新器中注册一个收集器。

注册后,更新器将不再接受另一个收集器。

参数:

collector (DataCollectorBase) – 要注册的 collector。

register_post_hook(hook: Callable[[], None])[来源]

注册一个后置钩子,在权重更新后调用。

参数:

hook (Callable[[], None]) – 要注册的后置钩子。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源