vLLMUpdater¶
- class torchrl.collectors.llm.vLLMUpdater(*args, v2=False, **kwargs)[源代码]¶
将权重发送到 vLLM 工作节点的类。
此类负责在训练策略和 vLLM 推理工作节点之间同步权重。它支持本地 vLLM 实例和远程 Ray Actor。
- 参数:
master_address (str, optional) – 分布式训练的主地址。默认为 localhost。
master_port (int, optional) – 分布式训练的主端口。如果为 None,则会自动分配。
model_metadata (dict[str, tuple[torch.dtype, torch.Size]], optional) – 模型元数据,将参数名称映射到它们的 dtype 和 shape。如果未提供,将从策略中提取。
vllm_tp_size (int, optional) – vLLM 的张量并行大小。默认为 1。
v2 (bool, optional) – 如果为 True,则返回 vLLMUpdaterV2 实例。这是一个实验性功能,提供了与 AsyncVLLM 引擎更好的集成。使用 v2=True 时,必须提供 vllm_engine 参数而不是上述参数。默认为 False。
注意
此类假定策略是一个可由 vLLM 加载的 Transformers 模型。策略必须具有 state_dict() 方法,该方法返回模型权重。
警告
v2=True 选项是实验性的,在未来的版本中可能会有向后不兼容的更改。但是,它通常被认为是与 AsyncVLLM 引擎配合使用的更好选择,并提供更好的性能和可靠性。
- property collector: Any | None¶
接收器的收集器或容器。
如果容器超出范围或未设置,则返回None。
- property collectors: list[Any] | None¶
收集器或接收者容器。
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None ¶
可选的类方法,用于从策略创建权重更新器实例。
子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。
- 参数:
policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。
- 返回:
- 权重更新器的实例,或者如果策略无法创建实例则为 None。
无法用于创建实例的实例。
- 返回类型:
WeightUpdaterBase | None
- classmethod get_model_metadata(model: TensorDictModuleBase) dict[str, tuple[torch.dtype, torch.Size]] [源代码]¶
从模型中获取模型元数据。
- 参数:
model (TensorDictModuleBase) – 要从中获取元数据的模型。必须是 TransformersWrapper 或同等模型。
- 返回:
模型元数据。
- 返回类型:
dict[str, tuple[torch.dtype, torch.Size]]
- increment_version()¶
增加策略版本。
- init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]]) None [源代码]¶
使用模型元数据初始化更新器并初始化组。
- 参数:
model_metadata (dict[str, tuple[torch.dtype, torch.Size]]) – 模型元数据,将参数名称映射到它们的 dtype 和 shape。
- property post_hooks: list[collections.abc.Callable[[], None]]¶
注册到权重更新器的后置钩子列表。
- push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)¶
更新策略的权重,或在指定/所有远程工作进程上更新。
- 参数:
policy_or_weights – 从中获取权重的来源。可以是: - TensorDictModuleBase:将提取权重的策略模块 - TensorDictBase:包含权重的 TensorDict - dict:一个包含权重的普通字典 - None:将尝试使用 _get_server_weights() 从服务器获取权重。
worker_ids – 要更新的工作进程的可选列表。
返回:无。
- register_collector(collector: DataCollectorBase)[源代码]¶
在更新器中注册一个收集器。
注册后,更新器将不再接受另一个收集器。
- 参数:
collector (DataCollectorBase) – 要注册的 collector。
- register_post_hook(hook: Callable[[], None])¶
注册一个后置钩子,在权重更新后调用。
- 参数:
hook (Callable[[], None]) – 要注册的后置钩子。