快捷方式

RemoteTransformersWrapper

class torchrl.modules.llm.RemoteTransformersWrapper(model, max_concurrency: int = 16, validate_model: bool = True, actor_name: Optional[str] = None, num_gpus: int = 1, num_cpus: int = 1, **kwargs)[源代码]

一个远程 Ray actor 包装器,用于 TransformersWrapper,提供了一个简化的接口。

此类将 TransformersWrapper 实例包装为 Ray actor,允许远程执行,同时提供了一个不需要显式 remote()get() 调用的干净接口。

参数:
  • model (str) – 要包装的 Hugging Face Transformers 模型。必须是一个字符串(模型名称或路径),它将被传递给 transformers.AutoModelForCausalLM.from_pretrained。Transformers 模型不可序列化,因此仅支持模型名称/路径。

  • max_concurrency (int, optional) – 到远程 actor 的并发调用最大数量。默认为 16。

  • validate_model (bool, optional) – 是否验证模型。默认为 True。

  • num_gpus (int, optional) – 要使用的 GPU 数量。默认为 0。

  • num_cpus (int, optional) – 要使用的 CPU 数量。默认为 0。

  • **kwargs – 所有其他参数将直接传递给 TransformersWrapper。

示例

>>> import ray
>>> from torchrl.modules.llm.policies import RemoteTransformersWrapper
>>>
>>> # Initialize Ray if not already done
>>> if not ray.is_initialized():
...     ray.init()
>>>
>>> # Create remote wrapper
>>> remote_wrapper = RemoteTransformersWrapper(
...     model="gpt2",
...     input_mode="history",
...     generate=True,
...     generate_kwargs={"max_new_tokens": 50}
... )
>>>
>>> # Use like a regular wrapper (no remote/get calls needed)
>>> result = remote_wrapper(tensordict_input)
>>> print(result["text"].response)
property batching

批处理是否已启用。

cleanup_batching()[源代码]

清理批处理资源。

property collector

与模块关联的 collector。

property device

用于计算的设备。

property dist_params_keys

分布参数的键。

property dist_sample_keys

分布样本的键。

property generate

文本生成是否启用。

get_batching_state()[源代码]

获取当前的批处理状态。

get_dist(tensordict, **kwargs)[源代码]

使用可选的掩码获取分布(从 logits/log-probs)。

get_dist_with_prompt_mask(tensordict, **kwargs)[源代码]

获取仅包含响应 token(排除提示)的分布。

get_new_version(**kwargs)[源代码]

获取具有更改参数的新版本包装器。

property in_keys

输入键。

property inplace

是否使用原地操作。

property layout

输出张量使用的布局。

log_prob(data, **kwargs)[源代码]

计算对数概率。

property log_prob_keys

对数概率的键。

property log_probs_key

对数概率输出的键。

property masks_key

掩码输出的键。

property num_samples

要生成的样本数量。

property out_keys

输出键。

property pad_output

输出序列是否填充。

property text_key

文本输出的键。

property tokens_key

token 输出的键。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源