快捷方式

vLLMUpdater

class torchrl.collectors.llm.vLLMUpdater(master_address: str | None = None, master_port: int | None = None, model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None, vllm_tp_size: int | None = None)[source]

一个将权重发送到 vLLM 工作进程的类。

此类负责在训练策略和 vLLM 推理工作进程之间同步权重。它支持本地 vLLM 实例和远程 Ray Actor。

参数:
  • master_address (str, optional) – 分布式训练的主地址。默认为 localhost。

  • master_port (int, optional) – 分布式训练的主端口。如果为 None,将自动分配。

  • model_metadata (dict[str, tuple[torch.dtype, torch.Size]], optional) – 模型元数据,将参数名称映射到它们的 dtype 和 shape。如果未提供,将从策略中提取。

  • vllm_tp_size (int, optional) – vLLM 张量并行大小。默认为 1。

init()[source]

使用模型元数据初始化更新器并初始化组。

_sync_weights_with_worker()[source]

与 vLLM 工作进程同步权重。

_get_server_weights()[source]

未使用 - 权重必须直接传递。

_maybe_map_weights()[source]

无需映射。

all_worker_ids()[source]

返回 [0],因为我们只有一个工作进程。

注意

此类假定策略是一个可以被 vLLM 加载的 Transformer 模型。策略必须有一个 state_dict() 方法,该方法返回模型权重。

all_worker_ids() list[int][source]

返回 [0],因为我们只有一个工作进程。

property collector: torch.collector.DataCollectorBase

接收器的收集器或容器。

如果容器超出范围或未设置,则返回None

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

可选的类方法,用于从策略创建权重更新器实例。

子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。

参数:

policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。

返回:

权重更新器的实例,或者如果策略无法创建实例则为 None。

无法用于创建实例的实例。

返回类型:

WeightUpdaterBase | None

classmethod get_model_metadata(model: TensorDictModuleBase) dict[str, tuple[torch.dtype, torch.Size]][source]

从模型中获取模型元数据。

参数:

model (TensorDictModuleBase) – 要从中获取元数据的模型。必须是 TransformersWrapper 或等效模型。

返回:

模型元数据。

返回类型:

dict[str, tuple[torch.dtype, torch.Size]]

increment_version()

增加策略版本。

init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]]) None[source]

使用模型元数据初始化更新器并初始化组。

参数:

model_metadata (dict[str, tuple[torch.dtype, torch.Size]]) – 模型元数据,将参数名称映射到它们的 dtype 和 shape。

property post_hooks: list[Callable[[], NoneType]]

注册到权重更新器的后置钩子列表。

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)

更新策略的权重,或在指定/所有远程工作进程上更新。

参数:
  • policy_or_weights – 获取权重的来源。可以是: - TensorDictModuleBase:一个策略模块,其权重将被提取 - TensorDictBase:一个包含权重的 TensorDict - dict:一个常规的 dict,包含权重 - None:将尝试使用 _get_server_weights() 从服务器获取权重

  • worker_ids – 要更新的工作进程的可选列表。

返回:无。

register_collector(collector: DataCollectorBase)[source]

在更新器中注册一个收集器。

注册后,更新器将不再接受另一个收集器。

参数:

collector (DataCollectorBase) – 要注册的 collector。

register_post_hook(hook: Callable[[], None])

注册一个后置钩子,在权重更新后调用。

参数:

hook (Callable[[], None]) – 要注册的后置钩子。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源