快捷方式

vLLMUpdaterV2

class torchrl.collectors.llm.vLLMUpdaterV2(vllm_engine: RLvLLMEngine)[源]

使用 RLvLLMEngine 接口的简化 vLLM 权重更新器。

此更新器可与任何实现 RLvLLMEngine 接口的 vLLM 引擎配合使用,自动提取配置并通过引擎自己的方法处理权重更新。

参数:

vllm_engine – 实现 RLvLLMEngine 接口的 vLLM 引擎。

注意

可以通过 torchrl.collectors.llm.vLLMUpdaterv2=True 来创建此类。

all_worker_ids()[源]

返回工作 ID 列表。

property collector: Any | None

接收器的收集器或容器。

如果容器超出范围或未设置,则返回None

property collectors: list[Any] | None

收集器或接收者容器。

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

可选的类方法,用于从策略创建权重更新器实例。

子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。

参数:

policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。

返回:

权重更新器的实例,或者如果策略无法创建实例则为 None。

无法用于创建实例的实例。

返回类型:

WeightUpdaterBase | None

classmethod get_model_metadata(model) dict[str, tuple[torch.dtype, torch.Size]][源]

从模型中获取模型元数据。

参数:

model – 具有 state_dict() 方法的模型(例如,TransformersWrapper)

返回:

参数名称到 (dtype, shape) 元组的映射

返回类型:

dict

get_tp_size() int[源]

获取张量并行大小。

increment_version()

增加策略版本。

init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None) None[源]

初始化权重更新器。

参数:

model_metadata – 可选的模型元数据。如果未提供,则使用引擎的元数据。

property post_hooks: list[collections.abc.Callable[[], None]]

注册到权重更新器的后置钩子列表。

push_weights(weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase)[源]

将权重推送到 vLLM 引擎。

参数:

weights – (name, tensor) 对的迭代器或 TensorDictBase

push_weights_from_transformers(transformers_model)[源]

从 transformers 模型推送权重。

参数:

transformers_model – Transformers PreTrainedModel 或 TorchRL 包装器

push_weights_from_transformers_optimized(transformers_model, batch_size=50)[源]

push_weights_from_transformers 的优化版本,支持 GPU 预加载。

此方法提供了多项优化:1. 在传输前将所有权重预加载到 GPU。2. 可选地批处理权重以实现更好的内存管理。3. 在可能的情况下使用非阻塞传输。

参数:
  • transformers_model – Transformers PreTrainedModel 或 TorchRL 包装器

  • batch_size – 每次传输的权重数量(0 = 不分批)

register_collector(collector)[源]

注册一个收集器并设置策略版本增量后置钩子。

参数:

collector – 要注册的收集器(DataCollectorBase)

register_post_hook(hook: Callable[[], None])

注册一个后置钩子,在权重更新后调用。

参数:

hook (Callable[[], None]) – 要注册的后置钩子。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源