vLLMUpdaterV2¶
- class torchrl.collectors.llm.vLLMUpdaterV2(vllm_engine: RLvLLMEngine)[源]¶
使用 RLvLLMEngine 接口的简化 vLLM 权重更新器。
此更新器可与任何实现 RLvLLMEngine 接口的 vLLM 引擎配合使用,自动提取配置并通过引擎自己的方法处理权重更新。
- 参数:
vllm_engine – 实现 RLvLLMEngine 接口的 vLLM 引擎。
注意
可以通过
torchrl.collectors.llm.vLLMUpdater
和 v2=True 来创建此类。- property collector: Any | None¶
接收器的收集器或容器。
如果容器超出范围或未设置,则返回None。
- property collectors: list[Any] | None¶
收集器或接收者容器。
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None ¶
可选的类方法,用于从策略创建权重更新器实例。
子类可以实现此方法以提供基于策略的自定义初始化逻辑。如果实现,将在收集器中初始化权重更新器时调用此方法,然后再回退到默认构造函数。
- 参数:
policy (TensorDictModuleBase) – 要从中创建权重更新器的策略。
- 返回:
- 权重更新器的实例,或者如果策略无法创建实例则为 None。
无法用于创建实例的实例。
- 返回类型:
WeightUpdaterBase | None
- classmethod get_model_metadata(model) dict[str, tuple[torch.dtype, torch.Size]] [源]¶
从模型中获取模型元数据。
- 参数:
model – 具有 state_dict() 方法的模型(例如,TransformersWrapper)
- 返回:
参数名称到 (dtype, shape) 元组的映射
- 返回类型:
dict
- increment_version()¶
增加策略版本。
- init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None) None [源]¶
初始化权重更新器。
- 参数:
model_metadata – 可选的模型元数据。如果未提供,则使用引擎的元数据。
- property post_hooks: list[collections.abc.Callable[[], None]]¶
注册到权重更新器的后置钩子列表。
- push_weights(weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase)[源]¶
将权重推送到 vLLM 引擎。
- 参数:
weights – (name, tensor) 对的迭代器或 TensorDictBase
- push_weights_from_transformers(transformers_model)[源]¶
从 transformers 模型推送权重。
- 参数:
transformers_model – Transformers PreTrainedModel 或 TorchRL 包装器
- push_weights_from_transformers_optimized(transformers_model, batch_size=50)[源]¶
push_weights_from_transformers 的优化版本,支持 GPU 预加载。
此方法提供了多项优化:1. 在传输前将所有权重预加载到 GPU。2. 可选地批处理权重以实现更好的内存管理。3. 在可能的情况下使用非阻塞传输。
- 参数:
transformers_model – Transformers PreTrainedModel 或 TorchRL 包装器
batch_size – 每次传输的权重数量(0 = 不分批)
- register_post_hook(hook: Callable[[], None])¶
注册一个后置钩子,在权重更新后调用。
- 参数:
hook (Callable[[], None]) – 要注册的后置钩子。