快捷方式

LLM 接口

TorchRL 为 LLM 的训练后和微调提供了全面的框架。LLM API 基于五个核心概念构建,这些概念共同作用,为语言模型创建完整的强化学习管道。

  1. 数据表示 (数据结构):用于处理对话、文本解析和 LLM 输出类。这包括用于管理对话上下文的 History 类,以及用于 token、对数概率和文本的结构化输出类。

  2. LLM 包装器 API (模块):用于不同 LLM 后端的统一接口,包括用于 Hugging Face 模型的 TransformersWrapper,用于 vLLM 推理的 vLLMWrapper,以及用于高性能分布式 vLLM 推理(推荐)的 AsyncVLLM。这些包装器在不同后端之间提供了一致的输入/输出格式,并为损失计算、数据存储、评分、权重同步等提供了一个集成的接口。

  3. 环境 (环境):负责数据加载、工具执行、奖励计算和格式化的编排层。这包括用于对话管理的 ChatEnv、数据集环境以及用于工具集成的各种转换。

  4. 目标 (目标):用于 LLM 训练的专用损失函数,包括用于组相对策略优化(Group Relative Policy Optimization)的 GRPOLoss 和用于监督微调(supervised fine-tuning)的 SFTLoss

  5. 收集器 (收集器):收集器用于从环境中收集数据并将其存储为可用于训练的格式。这包括用于从环境中收集数据的 LLMCollector 和使用 Ray 在分布式环境中收集数据的 RayLLMCollector

这些组件协同工作,构建一个完整的管道:环境负责加载和格式化数据,LLM 包装器负责推理,数据结构负责维护对话上下文,目标负责计算训练损失。模块化设计允许您根据具体用例混合搭配组件。

sota-implementations/grpo/ 目录中可以找到一个使用 LLM API 的完整示例。训练编排涉及三个主要组件:

  • 数据收集器:持有对环境和推理模型或引擎的引用。它收集数据,放入缓冲区,并处理权重更新。

  • 回放缓冲区:存储收集到的数据并执行任何预处理或后处理步骤。这些可能包括:- 使用基于蒙特卡罗的方法进行优势估计(使用 MCAdvantage 转换);- 对输出进行评分;- 日志记录等。

  • 训练器:处理训练循环,包括优化步骤、序列化、日志记录和权重更新初始化。

警告

LLM API 仍在开发中,未来可能会发生变化。欢迎提供反馈、报告问题和提交 PR!

数据结构

数据表示层为以结构化的方式处理对话和 LLM 输出奠定了基础。

History 类

与 Transformers 中通常存在的聊天格式(请参阅 Hugging Face 聊天文档)相比,History 类是 TensorClass 的版本。它提供了一个全面的 API 来管理对话数据,功能包括:

  • 文本解析和格式化:使用 from_text()apply_chat_template() 在文本和结构化对话格式之间进行转换。

  • 动态对话构建:使用 append()extend() 方法追加和扩展对话。

  • 多模型支持:自动检测不同模型系列(Qwen、DialoGPT、Falcon、DeepSeek 等)的模板。

  • 助手 token 屏蔽:识别为强化学习应用而由助手生成的 token。

  • 工具调用支持:在对话中处理函数调用和工具响应。

  • 批处理操作:用于同时处理多个对话的高效张量操作。

History(role, content[, is_complete, ...])

ContentBase(type, text, url, data, ...[, ...])

支持的模型系列

我们目前支持以下模型系列进行字符串到 History 的解析或助手 token 屏蔽:

  • Qwen 系列(例如,Qwen/Qwen2.5-0.5B):具有完整工具调用支持的自定义模板。

  • DialoGPT 系列(例如,microsoft/DialoGPT-medium):用于对话格式的自定义模板。

  • Falcon 系列(例如,tiiuae/falcon-7b-instruct):用于指令格式的自定义模板。

  • DeepSeek 系列(例如,deepseek-ai/deepseek-coder-6.7b-base):具有原生格式的自定义模板。

其他模型也得到支持,但您需要为它们提供自定义模板。LLAMA、Mistral、OPT、GPT、MPT、BLOOM、Pythia、Phi 等将使用默认的 chatml_format 模板。

用法

>>> from torchrl.data.llm.chat import History
>>> from transformers import AutoTokenizer
>>>
>>> # Create a conversation history
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"},
...     {"role": "user", "content": "How are you?"},
...     {"role": "assistant", "content": "I'm doing well, thanks!"}
... ]])
>>>
>>> # Load any supported tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
>>>
>>> # Apply chat template with assistant token masking
>>> result = history.apply_chat_template(
...     chat_template_name="qwen",
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )
>>>
>>> # The result contains an assistant_masks tensor
>>> assistant_masks = result["assistant_masks"]
>>> print(f"Assistant tokens: {assistant_masks.sum().item()}")

添加自定义模板

您可以使用 torchrl.data.llm.chat.add_chat_template() 函数为新模型系列添加自定义聊天模板。

用法示例

添加 Llama 模板
>>> from torchrl.data.llm.chat import add_chat_template, History
>>> from transformers import AutoTokenizer
>>>
>>> # Define the Llama chat template
>>> llama_template = '''
... {% for message in messages %}
... {%- if message['role'] == 'user' %}
... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
... {%- elif message['role'] == 'assistant' %}
... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
... {%- endif %}
... {% endfor %}
... {%- if add_generation_prompt %}
... {% generation %}{{ ' ' }}{% endgeneration %}
... {%- endif %}
... '''
>>>
>>> # Define the inverse parser for Llama format
>>> def parse_llama_text(text: str) -> History:
...     import re
...     pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
...     matches = re.findall(pattern, text, re.DOTALL)
...     messages = []
...     for user_content, assistant_content in matches:
...         messages.append(History(role="user", content=user_content.strip()))
...         messages.append(History(role="assistant", content=assistant_content.strip()))
...     return lazy_stack(messages)
>>>
>>> # Add the template with auto-detection
>>> add_chat_template(
...     template_name="llama",
...     template=llama_template,
...     inverse_parser=parse_llama_text,
...     model_family_keywords=["llama", "meta-llama"]
... )
>>>
>>> # Now you can use it with auto-detection
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"}
... ]])
>>>
>>> # Auto-detection will use the llama template
>>> result = history.apply_chat_template(
...     tokenizer=tokenizer,
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )

测试您的自定义模板

在添加自定义模板时,您应该对其进行测试以确保其正常工作。以下是建议的测试:

助手 token 屏蔽测试

测试您的模板是否支持助手 token 屏蔽。

import pytest
from torchrl.data.llm.chat import History, add_chat_template
from transformers import AutoTokenizer

def test_my_model_assistant_masking():
    """Test that your model supports assistant token masking."""
    # Add your template first
    add_chat_template(
        template_name="my_model",
        template="your_template_here",
        model_family_keywords=["my_model"]
    )

    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'}
    ]])

    result = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="my_model",
        add_generation_prompt=False,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )

    # Verify assistant mask is present
    assert 'assistant_masks' in result
    assert result['assistant_masks'].shape[0] == 1, "Should have batch dimension of 1"
    assert result['assistant_masks'].shape[1] > 0, "Should have sequence length > 0"

    # Verify some assistant tokens are masked
    assistant_token_count = result['assistant_masks'].sum().item()
    assert assistant_token_count > 0, "Should have assistant tokens masked"
    print(f"✓ {assistant_token_count} assistant tokens masked")
模板等价性测试

测试您的自定义模板是否产生与模型默认模板相同的输出(不包括屏蔽)。

def test_my_model_template_equivalence():
    """Test that your template matches the model's default template."""
    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'},
        {'role': 'user', 'content': 'How are you?'},
        {'role': 'assistant', 'content': 'I\'m good, thanks!'},
    ]])

    # Get output with model's default template
    try:
        default_out = history.apply_chat_template(
            tokenizer=tokenizer,
            add_generation_prompt=False,
            chat_template=tokenizer.chat_template,
            tokenize=False,
        )
    except Exception as e:
        default_out = None
        print(f"[WARN] Could not get default template: {e}")

    # Get output with your custom template
    custom_out = history.apply_chat_template(
        tokenizer=tokenizer,
        add_generation_prompt=False,
        chat_template_name="my_model",
        tokenize=False,
    )

    if default_out is not None:
        # Normalize whitespace for comparison
        import re
        def norm(s):
            return re.sub(r"\s+", " ", s.strip())

        assert norm(default_out) == norm(custom_out), (
            f"Custom template does not match default!\n"
            f"Default: {default_out}\nCustom: {custom_out}"
        )
        print("✓ Template equivalence verified")
    else:
        print("[INFO] Skipped equivalence check (no default template available)")
反向解析测试

如果您提供了反向解析器,请测试其是否正常工作。

def test_my_model_inverse_parsing():
    """Test that your inverse parser works correctly."""
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'}
    ]])

    # Format using your template
    formatted = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="my_model",
        add_generation_prompt=False,
        tokenize=False,
    )

    # Parse back using your inverse parser
    parsed = History.from_text(formatted, chat_template_name="my_model")

    # Verify the parsing worked
    assert parsed.role == history.role
    assert parsed.content == history.content
    print("✓ Inverse parsing verified")

LLM 包装器 API

LLM 包装器 API 提供了不同 LLM 后端的统一接口,确保了训练和推理管道之间一致的输入/输出格式。主要的包装器是用于 Hugging Face 模型的 TransformersWrapper 和用于 vLLM 推理的 vLLMWrapper

数据结构类

包装器使用结构化的 TensorClass 对象来表示 LLM 数据的不同方面:

  • :class:`~torchrl.modules.llm.policies.Text`:包含具有 promptresponsefull 字段的文本数据。

  • :class:`~torchrl.modules.llm.policies.ChatHistory`:包含具有 promptresponsefull 字段的 History 对象。

  • :class:`~torchrl.modules.llm.policies.Tokens`:包含具有 promptresponsefull 字段的 token 化数据。

  • :class:`~torchrl.modules.llm.policies.LogProbs`:包含具有 promptresponsefull 字段的对数概率。

  • :class:`~torchrl.modules.llm.policies.Masks`:包含注意力掩码和助手掩码。

API 流程

包装器在两种不同的模式下运行:

生成模式 (`generate=True`):- 输入:从 prompt 字段读取(例如,history.prompttext.prompttokens.prompt)- 输出:写入 responsefull 字段。

  • response:仅包含新生成的内容。

  • full:包含完整的序列(prompt + response)。

对数概率模式 (`generate=False`):- 输入:从 full 字段读取(例如,history.fulltext.fulltokens.full)- 输出:将对数概率写入相应的 full 字段。

LLM-环境交互循环

LLM-Environment interaction loop

LLM-环境交互:LLM 生成响应,环境更新对话,转换可以注入新消息或工具。

在典型的 RL 或工具增强设置中,LLM 和环境在一个循环中交互:

  1. LLM 生成:LLM 包装器接收 prompt(当前对话历史),生成 response,并输出一个 full 字段。

包含 prompt 和 response 的连接。

  1. 环境步进:环境将 full 字段作为下一个 prompt 提供给 LLM。这确保了对话

上下文随着每次轮次而增长。更多详细信息请参阅 ref_env_llm_step

  1. 转换:在下一个 LLM 步进之前,转换可以修改对话,例如,通过插入新的用户消息、工具调用

或奖励注释。

  1. 重复:此过程重复进行所需的轮次数,从而实现多轮对话、工具使用和 RL 训练。

这种设计允许在每一步对对话进行灵活的增强,支持高级 RL 和工具使用场景。

典型的伪代码循环

# Get the first prompt out of an initial query
obs = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size, device=env.device))
while not done:
    # LLM generates a response given the current prompt
    llm_output = llm(obs)
    # Environment steps: creates a ("next", "history") field with the new prompt (from the previous `"full"` field)
    obs = env.step(llm_output)

与 History 集成

当使用 input_mode="history" 时,包装器可以与 History 类无缝集成。

  • 输入:接收包含 prompt 字段中 History 的 ChatHistory 对象。

  • 生成:应用聊天模板将 History 转换为 token,生成响应,然后将完整的文本解析回 History 对象。

  • 输出:返回一个 ChatHistory,其中包含:- prompt:原始对话历史- response:仅包含助手响应的新 History 对象- full:包含新响应已追加的完整对话历史。

此设计允许自然的对话流程,其中每个生成步骤都会扩展对话历史,使其成为多轮对话系统的理想选择。

Prompt 与 Response 及填充

LLM output data format (Tokens, Masks, Padded vs. Sparse)

LLM 输出结构:Token、LogProbs 和 Mask 的填充与稀疏表示。

上图说明了 TorchRL LLM API 中使用的主要输出类的结构:

  • Tokens(以及扩展的 LogProbs):- 填充格式:批次中的所有序列都填充到相同的长度(使用特殊的 pad token),使其适合张量操作。prompt 和 response 被连接起来形成 tokens.full,掩码指示有效位置与填充位置。- 稀疏格式:每个序列保留其原始长度(无填充),表示为张量列表。这对于可变长度数据更节省内存。

  • Masks:显示了两种主要掩码:- mask.attention_mask_all 标记有效(非 pad)token。- mask.assistant_mask_all 标记由助手生成的 token(对于 RLHF 和 SFT 训练有用)。

  • Text:未详细显示,因为它只是 prompt、response 或完整序列的解码字符串表示。

此格式确保所有 LLM 输出(Tokens、LogProbs、Masks、Text)都是一致的且易于操作,无论您使用填充批处理还是稀疏批处理。

总的来说,我们建议使用未填充的数据,因为它更节省内存且易于操作。例如,当从缓冲区收集多个填充元素时,很难清楚地理解如何重新填充它们以将它们组合成一个连贯的批次。使用未填充的数据更直接。

模块

LLM 包装器 API 提供了不同 LLM 后端的统一接口,确保了训练和推理管道之间一致的输入/输出格式。

包装器

这些原语的主要目标是:

  • 统一训练和推理管道的输入/输出数据格式。

  • 统一后端之间的输入/输出数据格式(以便在不同的损失和收集器中使用不同的后端)。

  • 提供适当的工具来在典型的 RL 环境中构建这些对象(资源分配、异步执行、权重更新等)。

LLMWrapperBase(*args, **kwargs)

LLM 包装器基类。

TransformersWrapper(*args, **kwargs)

Hugging Face Transformers 模型的包装器类,为文本生成和对数概率计算提供了一致的接口。

vLLMWrapper(*args, **kwargs)

vLLM 模型的包装器类,为文本生成和对数概率计算提供了一致的接口。

RemoteTransformersWrapper(model[, ...])

一个用于 TransformersWrapper 的远程 Ray actor 包装器,提供了一个简化的接口。

AsyncVLLM(engine_args[, num_replicas, ...])

一个管理多个异步 vLLM 引擎 actor 以进行分布式推理的服务。

ChatHistory([prompt, response, full, ...])

Text([prompt, response, full, device, names])

LogProbs([prompt, response, full, padded, ...])

Masks([all_attention_mask, ...])

Tokens([prompt, response, full, padded, ...])

远程包装器

TorchRL 提供了远程包装器类,可以使用 Ray 实现 LLM 包装器的分布式执行。这些包装器提供了一个简化的接口,不需要显式的 remote()get() 调用,使其易于在分布式环境中使用。

注意

对于 vLLM:请改用 AsyncVLLM

对于基于 vLLM 的推理,我们建议直接使用 AsyncVLLM 而不是远程包装器。AsyncVLLM 提供了更好的性能、资源利用率和内置负载均衡。有关详细信息,请参阅上面的 异步 vLLM 引擎(推荐) 部分。

远程包装器主要用于基于 Transformers 的模型或其他 AsyncVLLM 不适用的用例。

主要特性

  • 简化的接口:无需显式调用 remote()get()

  • 完整的 API 兼容性:公开了 LLMWrapperBase 基类的所有公共方法。

  • 自动 Ray 管理:内部处理 Ray 初始化和远程执行。

  • 属性访问:所有属性都可以通过远程包装器访问。

  • 错误处理:从远程 actor 正确传播错误。

  • 资源管理:支持上下文管理器以进行自动清理。

模型参数要求

  • RemoteTransformersWrapper:仅接受字符串模型名称/路径。Transformers 模型不可序列化。

支持的后端

目前,只有基于 Transformers 的模型通过远程包装器支持。对于 vLLM 模型,请改用 AsyncVLLM

用法示例

import ray
from torchrl.modules.llm.policies import RemoteTransformersWrapper
from torchrl.data.llm import History
from torchrl.modules.llm.policies import ChatHistory, Text
from tensordict import TensorDict

# Initialize Ray (if not already done)
if not ray.is_initialized():
    ray.init()

# Transformers wrapper (only string models supported)
# The remote wrappers implement context managers for proper resource cleanup:
with RemoteTransformersWrapper(
    model="gpt2",
    max_concurrency=16,
    input_mode="text",
    generate=True,
    generate_kwargs={"max_new_tokens": 30}
) as remote_transformers:

    text_input = TensorDict({"text": Text(prompt="Hello world")}, batch_size=(1,))
    result = remote_transformers(text_input)
    print(result["text"].response)

性能考虑

  • 网络开销:远程执行增加了网络通信开销。

  • 序列化:数据在发送到远程 actor 时会被序列化。

  • 内存:每个远程 actor 都维护自己的模型副本。

  • 并发:多个远程包装器可以并发运行。

  • 最大并发数:使用 max_concurrency 参数控制对每个远程 actor 的并发调用次数。

  • 清理:始终使用上下文管理器或调用 cleanup_batching() 以防止因批处理锁而挂起。

Utils

make_async_vllm_engine(model_name[, ...])

创建异步 vLLM 引擎服务。

stateless_init_process_group_async(...)

为分布式通信初始化一个无状态进程组(异步版本)。

make_vllm_worker(*, model_name[, devices, ...])

创建具有张量并行支持的 vLLM 推理引擎。

stateless_init_process_group(master_address, ...)

为分布式通信初始化一个无状态进程组。

收集器

TorchRL 提供专门的收集器类(LLMCollectorRayLLMCollector),这些类针对 LLM 用例进行了定制。我们还为一些推理引擎提供了专用的更新器。

有关收集器 API 的更多详细信息,请参阅 ref_collectors。简而言之,收集器的想法是将管道的推理部分隔离到一个专用类中。收集器通常以策略和环境为输入,并在两者之间交替运行。在“经典”设置中,策略类似于正在训练的策略(具有一些可选的额外探索)。在 LLM 微调的上下文中,策略通常是一个专业的推理引擎,例如 vLLM 服务器。收集器由以下参数和功能定义:

  • 同步/异步:收集器是否应以同步或异步模式运行。在同步模式下,收集器将在优化/训练步骤之间交替运行推理步骤。在异步模式下,收集器将与优化/训练步骤并行运行推理步骤。可以向收集器传递一个回放缓冲区,这样收集器就可以直接写入它。在其他情况下,收集器可以被迭代以收集数据。

  • 步数:收集器构建时具有一定的步数预算,以及在收集期间每次 yield 批次中要包含的步数。

  • 权重更新器:权重更新器是更新策略权重的类。将权重更新隔离到一个专用类中,可以根据策略规范轻松实现不同的权重更新策略。

策略版本跟踪

LLM 收集器还允许跟踪策略的版本,这对于某些用例很有用。这通过将 PolicyVersion 转换添加到环境中来实现,然后由收集器在每次权重更新后递增。为此,可以向收集器构造函数提供转换的状态版本或一个布尔值。

>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
>>> # init the updater
>>> collector.weight_updater.init(...)
>>> # the version is incremented after each weight update
>>> collector.update_policy_weights_(state_dict=...)
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
...     print(data["policy_version"])

vLLMUpdater(*args[, v2])

一个将权重发送到 vLLM worker 的类。

vLLMUpdaterV2(vllm_engine)

使用 RLvLLMEngine 接口的简化 vLLM 权重更新器。

LLMCollector(env, *[, policy, ...])

SyncDataCollector 的简化版本,用于 LLM 推理。

RayLLMCollector(env, *[, policy, ...])

LLM Collector 的轻量级 Ray 实现,可以远程扩展和采样。

环境

环境层负责数据加载、工具执行、奖励计算和格式化的编排。当使用 TorchRL 微调 LLM 时,环境是推理管道的关键组成部分,与策略和收集器并列。

ChatEnv

ChatEnv 是 LLM 环境的空白画布——它是一个基本工具,旨在通过添加特定功能的转换来扩展。基础 ChatEnv 提供了使用 History 格式管理对话状态的基本结构,但它故意保持最小化以实现最大的灵活性。

核心功能

ChatEnv 在三种主要模式下运行:- History 模式:使用 History 对象进行对话管理。- Text 模式:使用简单的文本字符串进行输入/输出。- Tokens 模式:使用 token 化数据进行输入/输出。

环境通过以下方式维护对话状态:- 重置:初始化带有可选系统 prompt 的新对话。- 步进:接收 LLM 的响应并更新对话历史,准备下一个 prompt。

基于转换的架构

转换是扩展 ChatEnv 以实现特定功能的​​主要方式:

与 LLM 包装器集成

ChatEnv 设计为与 TransformersWrappervLLMWrapper 无缝协同工作。环境负责管理对话状态,而包装器负责实际的 LLM 推理,实现了清晰的关注点分离。

在每次调用 step 时,环境会:

  • 接收 LLM 的输出,特别是 full 字段,其中包含到目前为止的整个对话,包括新响应(例如,history.fulltext.fulltokens.full)。

  • 将此 full 字段设置为下一个 LLM 步进的 prompt(例如,td[“next”, “history”].prompttd[“next”, “text”].prompttd[“next”, “tokens”].prompt)。

  • 可以选择应用转换以在下一个 LLM 步进之前插入新的用户消息、工具调用或其他对话修改,以优化 prompt。

这种机制支持无缝的多轮交互,并支持复杂的用例,如工具使用和奖励塑形。

特定任务的环境

我们提供了一些特定任务的环境,例如用于 GSM8K 数据集的 GSM8KEnv,用于 IFEval 数据集的 IFEvalEnv,以及用于 MLGym 集成的 MLGymEnv

这些环境包装了一个 ChatEnv,并在一个 TransformedEnv 类中添加了一个 DataLoadingPrimer 转换(以及一个可选的奖励解析转换)。

ChatEnv(*args, **kwargs)

一个用于 LLM 的聊天环境,设计为一个用于对话和 RL 的空白画布。

DatasetChatEnv(*args, **kwargs)

用于从数据集中提取查询的聊天环境的基类。

GSM8KEnv(*args, **kwargs)

GSM8K 数据集环境。

make_gsm8k_env([dataset, num_envs, repeats, ...])

一个基于 LLMEnv 的 GSM8K 环境的构建器。

GSM8KPrepareQuestion([in_keys, out_keys])

在使用 GSM8k 作为 LLMEnv 的一部分时准备 prompt 的转换。

IFEvalEnv(*args, **kwargs)

基于 IFEval 数据集的聊天环境。

IfEvalScorer(*[, instruction_ids_key, ...])

IF-Eval 任务的评分器。

IFEvalScoreData(prompt_level_strict_acc, ...)

LLMEnv(*args, **kwargs)

用于语言模型的文本生成环境。

LLMHashingEnv(*args, **kwargs)

一个使用哈希模块来识别唯一观测值的文本生成环境。

make_mlgym(*[, task, tasks, tokenizer, ...])

将 MLGymEnv 包装成 TorchRL 环境。

MLGymWrapper(*args, **kwargs)

MLGym 环境的薄包装器。

GSM8KRewardParser(tokenizer[, in_keys, ...])

用于 GSM8KEnv 或 make_gsm8k_env 的奖励解析器。

变换 (Transforms)

转换用于在将数据传递给 LLM 之前修改数据。工具通常作为转换实现,并附加到一个基础环境,如 ChatEnv

一个工具转换的例子是 PythonInterpreter 转换,它用于在 LLM 响应的上下文中执行 Python 代码。

>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
...     tokenizer=tokenizer,
...     system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
...     user_role="user",
...     system_role="system",
...     batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
...     text="Let's write a Python function that returns the square of a number.",
...     batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
...     return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(reset_data)
>>> # The history is a stack of chat messages.
>>> #  The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
 'You are an assistant that can execute Python code. Decorate your code with '
 '```python``` tags.<|im_end|>\n'
 '<|im_start|>user\n'
 "Let's write a Python function that returns the square of a "
 'number.<|im_end|>\n'
 '<|im_start|>assistant\n'
 'Here is a block of code to be executed in python:\n'
 '```python\n'
 'def square(x):\n'
 '    return x * x\n'
 "print('testing the square function with input 2:', square(2))\n"
 '```<|im_end|>\n'
 '<|im_start|>user\n'
 '<tool_response>\n'
 'Code block 1 executed successfully:\n'
 'testing the square function with input 2: 4\n'
 '\n'
 '</tool_response><|im_end|>\n'
 '<|im_start|>assistant\n']

同样,从数据集中加载数据的环境只是 ChatEnv 的特殊实例,并增加了 DataLoadingPrimer 转换(以及一些专门的奖励解析转换)。

设计奖励转换

在为 LLM 环境设计奖励转换时,必须考虑几个关键因素,以确保与训练管道的正确集成。 GSM8KRewardParserIfEvalScorer 的示例为奖励转换设计提供了绝佳的模板。

奖励形状要求

奖励张量必须具有与 logits 相同的维度数,这通常比环境批次大小多两个维度。

  • 稀疏奖励:形状 (*bsz, 1, 1) - 每个序列一个奖励。

  • 密集奖励:形状 (*bsz, num_tokens, 1) - 每个 token 一个奖励。

此形状要求确保与损失计算管道兼容。例如,在 GSM8K 奖励解析器中:

# Rewards need to have shape broadcastable to [batch x tokens x 1]
tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))

Done 状态管理

妥善管理 done 状态对于防止无限生成至关重要。常见策略包括:

  1. 完成为基础的终止:当响应完成时设置 done(例如,History.complete=True)。

  2. 基于内容的终止:检测到特定内容时设置 done(例如,<answer> 块)。

  3. 基于步数的终止:使用 StepCounter 来预设步数限制。

IFEvalScorer 的示例

if self.set_done_if_answer and bool(answer_blocks):
    next_tensordict.set("done", torch.ones(...))
    next_tensordict.set("terminated", torch.ones(...))

输入模式处理

奖励转换必须正确处理不同的输入模式:

  • History 模式:从 ("history", "full")("history", "response") 中提取文本。

  • Text 模式:直接使用 ("text", "full")("text", "response") 中的文本。

  • Tokens 模式:从 ("tokens", "full")("tokens", "response") 解码 token。

GSM8K 奖励解析器演示了此模式。

if input_mode == "history":
    responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
    if hasattr(responses, "content"):
        text_completion = responses.content
elif input_mode == "text":
    text_completion = responses
elif input_mode == "tokens":
    text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())

规范管理

准确指定奖励和观察规范对于正确初始化环境至关重要。GSM8K 和 IFEval 都提供了很好的示例。

def transform_reward_spec(self, reward_spec: Composite) -> Composite:
    shape = reward_spec.shape + (1, 1)
    reward_spec.update(
        Composite(
            reward_answer=Unbounded(shape),
            reward_think=Unbounded(shape),
            reward_right=Unbounded(shape),
            reward_contained=Unbounded(shape),
            reward=Unbounded(shape),
            success=Unbounded(shape, dtype=torch.bool),
        )
    )
    return reward_spec

批处理注意事项

为了高效处理,请妥善处理批处理数据:

  1. 展平批次维度:使用 tensordict.view(-1) 进行处理。

  2. 重塑结果:处理后恢复原始批次结构。

  3. 处理可变长度序列:使用适当的填充和屏蔽。

奖励聚合策略

考虑不同的奖励聚合方法:

  1. 简单聚合:对多个奖励组件求和或取平均。

  2. 加权聚合:对不同组件应用不同权重。

  3. 条件奖励:基于特定条件或阈值设置奖励。

IFEvalScorer 演示了复杂的聚合策略。

def default_reward_aggregator(self, score: IFEvalScoreData, ...):
    # Format score (max 1.0)
    format_score = (format_components * weights).sum(dim=-1, keepdim=True)

    # Structure score (max 1.0)
    structure_score = think_score + answer_score

    # Completion bonus (max 0.2)
    completion_bonus = float(complete) * 0.2

    return format_score + structure_score + completion_bonus

回放缓冲区中的后处理

奖励也可以通过将转换附加到回放缓冲区来事后计算。但是,done 状态捕获必须保留在环境转换中,因为它需要在数据收集期间即时发生。

错误处理和鲁棒性

实现鲁棒的错误处理以应对解析失败。

try:
    cot, potential_answer = self.extract_tags(compl)
except ET.ParseError:
    cot, potential_answer = ("", "")

性能考虑

  1. 避免冗余计算:在可能的情况下缓存解析结果。

  2. 使用高效的文本处理:根据需要利用正则表达式或 XML 解析。

  3. 最小化内存分配:重用张量并避免不必要的复制。

通过遵循这些设计原则,可以将奖励转换有效地集成到 LLM 训练管道中,同时保持性能和可靠性。

AddThinkingPrompt(cond[, prompt, ...])

一个添加思考 prompt 以鼓励 LLM 重新考虑其响应的转换。

BrowserTransform([allowed_domains, ...])

一个启用网页浏览功能的转换。

DataLoadingPrimer(*args[, use_ray_service])

一个从数据加载器加载数据并使用 stack_method 将其转换为 tensordict 的 primer。

KLComputation([gen_log_probs_full_key, ...])

一个用于计算两个对数概率张量之间的 KL 散度,并可选地将其添加到奖励中的转换。

KLRewardTransform(*args[, use_ray_service])

用于计算基于 KL 散度的奖励的旧转换。

MCPToolTransform(tools, tool_schemas[, ...])

一个在 LLM 操作响应中执行 MCP 风格工具的转换。

PolicyVersion(version_type, ] =)

一个跟踪策略版本的转换。

PythonInterpreter([tokenizer, tool_name, ...])

一个在 LLM 响应中执行 Python 代码的转换。

RayDataLoadingPrimer(*[, dataloader, ...])

一个 DataLoadingPrimer,它创建了一个可以被多个环境共享的单个 actor。

RetrieveKL(*args[, use_ray_service])

一个用于检索两个模型对数概率之间 KL 散度的转换。

RetrieveLogProb(model, *[, ...])

一个用于从模型检索对数概率以进行 KL 散度计算的转换。

TemplateTransform(tokenizer[, chat_template])

一个在正向传播期间映射应用聊天模板到输入字符串,并在反向传播期间将字符串解析回模板的转换。

Tokenizer([in_keys, out_keys, in_keys_inv, ...])

对指定输入应用分词操作。

as_nested_tensor(list_of_tensordicts)

将 tensordict 列表堆叠成具有嵌套张量的单个 tensordict。

as_padded_tensor(list_of_tensordicts[, dim, ...])

将 tensordict 列表堆叠成具有填充张量的单个 tensordict。

目标

LLM 的训练后需要专门的损失函数,这些函数经过调整以适应语言模型的独特特性。

GRPO

GRPOLoss 类是 PPOLoss 类的薄包装器,它封装了 LLM 特有的功能。

GRPOLoss(*args, **kwargs)

GRPO 损失。

GRPOLossOutput(loss_objective, ...[, ...])

MCAdvantage(grpo_size[, prompt_key, ...])

蒙特卡罗优势计算引擎。

SFT

SFTLoss(*args, **kwargs)

监督微调损失。

SFTLossOutput(loss_sft[, loss_kl_to_ref, ...])

TopKRewardSelector(total_dialog_turns, topk_size)

一个回放缓冲区转换,用于为每个 prompt 选择 top-k 奖励。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源