RemoteTransformersWrapper¶
- class torchrl.modules.llm.RemoteTransformersWrapper(model, max_concurrency: int = 16, validate_model: bool = True, actor_name: Optional[str] = None, num_gpus: int = 1, num_cpus: int = 1, **kwargs)[源代码]¶
一个远程 Ray actor 包装器,用于 TransformersWrapper,提供了一个简化的接口。
此类将 TransformersWrapper 实例包装为 Ray actor,允许远程执行,同时提供了一个不需要显式 remote() 和 get() 调用的干净接口。
- 参数:
model (str) – 要包装的 Hugging Face Transformers 模型。必须是一个字符串(模型名称或路径),它将被传递给 transformers.AutoModelForCausalLM.from_pretrained。Transformers 模型不可序列化,因此仅支持模型名称/路径。
max_concurrency (int, optional) – 到远程 actor 的并发调用最大数量。默认为 16。
validate_model (bool, optional) – 是否验证模型。默认为 True。
num_gpus (int, optional) – 要使用的 GPU 数量。默认为 0。
num_cpus (int, optional) – 要使用的 CPU 数量。默认为 0。
**kwargs – 所有其他参数将直接传递给 TransformersWrapper。
示例
>>> import ray >>> from torchrl.modules.llm.policies import RemoteTransformersWrapper >>> >>> # Initialize Ray if not already done >>> if not ray.is_initialized(): ... ray.init() >>> >>> # Create remote wrapper >>> remote_wrapper = RemoteTransformersWrapper( ... model="gpt2", ... input_mode="history", ... generate=True, ... generate_kwargs={"max_new_tokens": 50} ... ) >>> >>> # Use like a regular wrapper (no remote/get calls needed) >>> result = remote_wrapper(tensordict_input) >>> print(result["text"].response)
- property batching¶
批处理是否已启用。
- property collector¶
与模块关联的 collector。
- property device¶
用于计算的设备。
- property dist_params_keys¶
分布参数的键。
- property dist_sample_keys¶
分布样本的键。
- property generate¶
文本生成是否启用。
- property in_keys¶
输入键。
- property inplace¶
是否使用原地操作。
- property layout¶
输出张量使用的布局。
- property log_prob_keys¶
对数概率的键。
- property log_probs_key¶
对数概率输出的键。
- property masks_key¶
掩码输出的键。
- property num_samples¶
要生成的样本数量。
- property out_keys¶
输出键。
- property pad_output¶
输出序列是否填充。
- property text_key¶
文本输出的键。
- property tokens_key¶
token 输出的键。