快捷方式

LLM 接口

TorchRL 为 LLM 的训练后和微调提供了一个全面的框架。LLM API围绕五个核心概念构建,它们共同创建一个完整的语言模型强化学习管道。

  1. 数据表示 (数据结构):处理对话、文本解析和 LLM 输出类的基础。这包括用于管理对话上下文的 History 类,以及用于 token、对数概率和文本的结构化输出类。

  2. LLM 包装器 API (模块):用于不同 LLM 后端的统一接口,包括用于 Hugging Face 模型的 TransformersWrapper 和用于 vLLM 推理的 vLLMWrapper。这些包装器在不同后端之间提供了一致的输入/输出格式,以及用于损失计算、数据存储、评分、权重同步等的集成接口。

  3. 环境 (环境):管理数据加载、工具执行、奖励计算和格式化的编排层。这包括用于对话管理的 ChatEnv、数据集环境以及用于工具集成的各种转换。

  4. 目标 (目标):用于 LLM 训练的专用损失函数,包括用于组相对策略优化 (Group Relative Policy Optimization) 的 GRPOLoss 和用于监督微调 (supervised fine-tuning) 的 SFTLoss

  5. 收集器 (收集器):收集器用于从环境中收集数据并将其存储为可用于训练的格式。这包括用于从环境中收集数据的 LLMCollector 和用于使用 Ray 在分布式环境中收集数据的 RayLLMCollector

这些组件协同工作,创建一个完整的管道:环境加载和格式化数据,LLM 包装器处理推理,数据结构维护对话上下文,目标计算训练损失。模块化设计允许您根据特定用例混合搭配组件。

LLM API 的完整示例可在 sota-implementations/grpo/ 目录中找到。训练编排涉及三个主要组件:

  • 数据收集器:持有对环境和推理模型或引擎的引用。它收集数据,将其放入缓冲区,并处理权重更新。

  • 回放缓冲区:存储收集到的数据并执行任何预处理或后处理步骤。这些可能包括:- 使用基于蒙特卡洛的方法进行优势估计(使用 MCAdvantage 转换);- 对输出进行评分;- 日志记录等。

  • 训练器:处理训练循环,包括优化步骤、序列化、日志记录和权重更新初始化。

警告

LLM API 仍在开发中,未来可能会发生变化。欢迎提供反馈、报告问题和提交拉取请求!

数据结构

数据表示层提供了以结构化方式处理对话和 LLM 输出的基础。

History 类

History 类是 transformers 中通常找到的聊天格式的 TensorClass 版本(请参阅 Hugging Face 聊天文档)。它提供了一个全面的 API,用于管理具有以下功能的对话数据:

  • 文本解析和格式化:使用 from_text()apply_chat_template() 在文本和结构化对话格式之间进行转换

  • 动态对话构建:使用 append()extend() 方法追加和扩展对话

  • 多模型支持:自动检测各种模型系列的模板(Qwen、DialoGPT、Falcon、DeepSeek 等)

  • 助手 token 掩码:识别为强化学习应用生成的 token

  • 工具调用支持:处理对话中的函数调用和工具响应

  • 批量操作:用于同时处理多个对话的高效张量操作。

History(role, content[, is_complete, ...])

ContentBase(type, text, url, data, ...[, ...])

支持的模型系列

我们目前支持以下模型系列进行字符串到 History 的解析或助手 token 掩码:

  • Qwen 系列(例如,Qwen/Qwen2.5-0.5B):具有完整工具调用支持的自定义模板

  • DialoGPT 系列(例如,microsoft/DialoGPT-medium):用于对话格式的自定义模板

  • Falcon 系列(例如,tiiuae/falcon-7b-instruct):用于指令格式的自定义模板

  • DeepSeek 系列(例如,deepseek-ai/deepseek-coder-6.7b-base):具有原生格式的自定义模板

支持其他模型,但您需要为它们提供自定义模板。LLAMA、Mistral、OPT、GPT、MPT、BLOOM、Pythia、Phi 等将使用默认的 chatml_format 模板。

用法

>>> from torchrl.data.llm.chat import History
>>> from transformers import AutoTokenizer
>>>
>>> # Create a conversation history
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"},
...     {"role": "user", "content": "How are you?"},
...     {"role": "assistant", "content": "I'm doing well, thanks!"}
... ]])
>>>
>>> # Load any supported tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
>>>
>>> # Apply chat template with assistant token masking
>>> result = history.apply_chat_template(
...     chat_template_name="qwen",
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )
>>>
>>> # The result contains an assistant_masks tensor
>>> assistant_masks = result["assistant_masks"]
>>> print(f"Assistant tokens: {assistant_masks.sum().item()}")

添加自定义模板

您可以使用 torchrl.data.llm.chat.add_chat_template() 函数为新的模型系列添加自定义聊天模板。

用法示例

添加 Llama 模板
>>> from torchrl.data.llm.chat import add_chat_template, History
>>> from transformers import AutoTokenizer
>>>
>>> # Define the Llama chat template
>>> llama_template = '''
... {% for message in messages %}
... {%- if message['role'] == 'user' %}
... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
... {%- elif message['role'] == 'assistant' %}
... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
... {%- endif %}
... {% endfor %}
... {%- if add_generation_prompt %}
... {% generation %}{{ ' ' }}{% endgeneration %}
... {%- endif %}
... '''
>>>
>>> # Define the inverse parser for Llama format
>>> def parse_llama_text(text: str) -> History:
...     import re
...     pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
...     matches = re.findall(pattern, text, re.DOTALL)
...     messages = []
...     for user_content, assistant_content in matches:
...         messages.append(History(role="user", content=user_content.strip()))
...         messages.append(History(role="assistant", content=assistant_content.strip()))
...     return lazy_stack(messages)
>>>
>>> # Add the template with auto-detection
>>> add_chat_template(
...     template_name="llama",
...     template=llama_template,
...     inverse_parser=parse_llama_text,
...     model_family_keywords=["llama", "meta-llama"]
... )
>>>
>>> # Now you can use it with auto-detection
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"}
... ]])
>>>
>>> # Auto-detection will use the llama template
>>> result = history.apply_chat_template(
...     tokenizer=tokenizer,
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )

测试您的自定义模板

添加自定义模板时,您应该对其进行测试以确保其正常工作。以下是推荐的测试:

助手 token 掩码测试

测试您的模板是否支持助手 token 掩码

import pytest
from torchrl.data.llm.chat import History, add_chat_template
from transformers import AutoTokenizer

def test_my_model_assistant_masking():
    """Test that your model supports assistant token masking."""
    # Add your template first
    add_chat_template(
        template_name="my_model",
        template="your_template_here",
        model_family_keywords=["my_model"]
    )

    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'}
    ]])

    result = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="my_model",
        add_generation_prompt=False,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )

    # Verify assistant mask is present
    assert 'assistant_masks' in result
    assert result['assistant_masks'].shape[0] == 1, "Should have batch dimension of 1"
    assert result['assistant_masks'].shape[1] > 0, "Should have sequence length > 0"

    # Verify some assistant tokens are masked
    assistant_token_count = result['assistant_masks'].sum().item()
    assert assistant_token_count > 0, "Should have assistant tokens masked"
    print(f"✓ {assistant_token_count} assistant tokens masked")
模板等效性测试

测试您的自定义模板是否产生与模型默认模板相同的输出(掩码除外)

def test_my_model_template_equivalence():
    """Test that your template matches the model's default template."""
    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'},
        {'role': 'user', 'content': 'How are you?'},
        {'role': 'assistant', 'content': 'I\'m good, thanks!'},
    ]])

    # Get output with model's default template
    try:
        default_out = history.apply_chat_template(
            tokenizer=tokenizer,
            add_generation_prompt=False,
            chat_template=tokenizer.chat_template,
            tokenize=False,
        )
    except Exception as e:
        default_out = None
        print(f"[WARN] Could not get default template: {e}")

    # Get output with your custom template
    custom_out = history.apply_chat_template(
        tokenizer=tokenizer,
        add_generation_prompt=False,
        chat_template_name="my_model",
        tokenize=False,
    )

    if default_out is not None:
        # Normalize whitespace for comparison
        import re
        def norm(s):
            return re.sub(r"\s+", " ", s.strip())

        assert norm(default_out) == norm(custom_out), (
            f"Custom template does not match default!\n"
            f"Default: {default_out}\nCustom: {custom_out}"
        )
        print("✓ Template equivalence verified")
    else:
        print("[INFO] Skipped equivalence check (no default template available)")
逆向解析测试

如果您提供了逆向解析器,请测试其是否正常工作

def test_my_model_inverse_parsing():
    """Test that your inverse parser works correctly."""
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'}
    ]])

    # Format using your template
    formatted = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="my_model",
        add_generation_prompt=False,
        tokenize=False,
    )

    # Parse back using your inverse parser
    parsed = History.from_text(formatted, chat_template_name="my_model")

    # Verify the parsing worked
    assert parsed.role == history.role
    assert parsed.content == history.content
    print("✓ Inverse parsing verified")

LLM 包装器 API

LLM 包装器 API 为不同的 LLM 后端提供了统一的接口,确保跨训练和推理管道的一致输入/输出格式。主要的包装器是用于 Hugging Face 模型的 TransformersWrapper 和用于 vLLM 推理的 vLLMWrapper

数据结构类

包装器使用结构化的 TensorClass 对象来表示 LLM 数据的不同方面。

  • :class:`~torchrl.modules.llm.policies.Text`:包含具有 promptresponsefull 字段的文本数据

  • :class:`~torchrl.modules.llm.policies.ChatHistory`:包含具有 promptresponsefull 字段的 History 对象

  • :class:`~torchrl.modules.llm.policies.Tokens`:包含具有 promptresponsefull 字段的 token 化数据

  • :class:`~torchrl.modules.llm.policies.LogProbs`:包含具有 promptresponsefull 字段的对数概率

  • :class:`~torchrl.modules.llm.policies.Masks`:包含注意力掩码和助手掩码

API 流程

包装器有两种不同的模式运行:

生成模式 (`generate=True`):- 输入:从 prompt 字段读取(例如,history.prompttext.prompttokens.prompt)- 输出:写入 responsefull 字段

  • response:仅包含新生成的​​内容

  • full:包含完整序列(prompt + response)

对数概率模式 (`generate=False`):- 输入:从 full 字段读取(例如,history.fulltext.fulltokens.full)- 输出:将对数概率写入相应的 full 字段

LLM-环境交互循环

LLM-Environment interaction loop

LLM-环境交互:LLM 生成响应,环境更新对话,转换可以注入新消息或工具。

在典型的 RL 或工具增强设置中,LLM 和环境在一个循环中交互:

  1. LLM 生成:LLM 包装器接收一个 prompt(当前对话历史),生成一个 response,并输出一个 full 字段

包含 prompt 和 response 的连接。

  1. 环境步:环境接收 full 字段并将其作为下一个 prompt 提供给 LLM。这确保了对话

上下文在每个回合中都会增长。有关更多详细信息,请参阅 ref_env_llm_step

  1. 转换:在下一个 LLM 步骤之前,转换可以修改对话——例如,通过插入新的用户消息、工具调用

或奖励注释。

  1. 重复:此过程根据需要重复多个回合,从而支持多回合对话、工具使用和 RL 训练。

此设计允许在每个步骤中灵活地增强对话,支持高级 RL 和工具使用场景。

典型的伪代码循环

# Get the first prompt out of an initial query
obs = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size, device=env.device))
while not done:
    # LLM generates a response given the current prompt
    llm_output = llm(obs)
    # Environment steps: creates a ("next", "history") field with the new prompt (from the previous `"full"` field)
    obs = env.step(llm_output)

与 History 集成

当使用 input_mode=”history” 时,包装器与 History 类无缝集成。

  • 输入:接收一个 ChatHistory 对象,其中 prompt 字段包含一个 History。

  • 生成:应用聊天模板将 History 转换为 token,生成响应,然后将完整文本解析回 History 对象。

  • 输出:返回一个 ChatHistory,其中:- prompt:原始对话历史- response:包含助手响应的新 History 对象- full:完整对话历史,并附加了新的响应

此设计允许自然的对话流程,其中每个生成步骤都会扩展对话历史,使其成为多回合对话系统的理想选择。

Prompt vs. Response 和 padding

LLM output data format (Tokens, Masks, Padded vs. Sparse)

LLM 输出结构:Token、LogProbs 和 Masks 的填充(padded)与稀疏(sparse)表示。

上图说明了 TorchRL 的 LLM API 中使用的主要输出类的结构:

  • Tokens(以及扩展的 LogProbs):- 填充格式:批次中的所有序列都填充到相同的长度(使用特殊填充 token),使其适用于张量操作。Prompt 和 response 连接形成 tokens.full,掩码指示有效位置与填充位置的关系。- 稀疏格式:每个序列保留其原始长度(无填充),表示为张量列表。这对于可变长度数据更节省内存。

  • Masks:显示了两个主要掩码:- mask.attention_mask_all 标记有效(非填充)token。- mask.assistant_mask_all 标记由助手生成的 token(对于 RLHF 和 SFT 训练很有用)。

  • Text:未详细显示,因为它只是 prompt、response 或完整序列的解码字符串表示。

此格式确保所有 LLM 输出(Tokens、LogProbs、Masks、Text)都保持一致且易于操作,无论您使用填充还是稀疏批处理。

通常,我们建议使用未填充的数据,因为它更节省内存且更易于操作。例如,当从缓冲区收集多个填充元素时,很难清楚地了解如何重新填充它们以将它们组合成一个连贯的批次。使用未填充的数据更直接。

模块

LLM 包装器 API 为不同的 LLM 后端提供了统一的接口,确保跨训练和推理管道的一致输入/输出格式。

包装器

这些原语的主要目标是:

  • 统一训练和推理管道中的输入/输出数据格式

  • 统一后端之间的输入/输出数据格式(以便可以在不同损失和收集器中使用不同的后端)

  • 提供适当的工具来在典型的 RL 设置中构建这些对象(资源分配、异步执行、权重更新等)

LLMWrapperBase(*args, **kwargs)

LLM 包装器基类。

TransformersWrapper(*args, **kwargs)

Hugging Face Transformers 模型的包装器类,为文本生成和对数概率计算提供了一致的接口。

vLLMWrapper(*args, **kwargs)

vLLM 模型的包装器类,为文本生成和对数概率计算提供了一致的接口。

ChatHistory([prompt, response, full, ...])

Text([prompt, response, full, device, names])

LogProbs([prompt, response, full, padded, ...])

Masks([all_attention_mask, ...])

Tokens([prompt, response, full, padded, ...])

工具

LLMOnDevice(*args[, bundle_indices])

一个围绕 vllm.LLM 的薄包装器,用于控制其设备放置。

make_vllm_worker(*, model_name[, devices, ...])

创建支持张量并行的 vLLM 推理引擎。

stateless_init_process_group(master_address, ...)

初始化一个无状态进程组以进行分布式通信。

vLLMWorker(*args, **kwargs)

Ray 的 vLLM 工作器。

收集器

TorchRL 提供专门的收集器类(LLMCollectorRayLLMCollector),它们针对 LLM 用例进行了定制。我们还为某些推理引擎提供专用更新程序。

有关收集器 API 的更多详细信息,请参阅 ref_collectors。简而言之,收集器的理念是将管道的推理部分隔离到一个专用的类中。收集器通常以策略和环境作为输入,并在运行一个和另一个之间进行交替。在“经典”设置中,策略类似于正在训练的策略(可能包含额外的探索)。在 LLM 微调的上下文中,策略通常是专门的推理引擎,例如 vLLM 服务器。收集器由以下参数和功能定义:

  • 同步/异步:收集器是应以同步还是异步模式运行。在同步模式下,收集器将在推理步骤与优化/训练步骤之间交替运行。在异步模式下,收集器将在推理步骤与优化/训练步骤并行运行。可以向收集器传递回放缓冲区,以便收集器可以直接写入其中。在其他情况下,可以迭代收集器来收集数据。

  • 步数:收集器会根据一定的步数预算构建,同时在收集期间生成的每个批次中包含一定数量的步数。

  • 权重更新程序:权重更新程序是更新策略权重的类。将权重更新隔离到一个专用的类中,可以根据策略规范轻松实现不同的权重更新策略。

策略版本跟踪

LLM 收集器还允许跟踪策略版本,这对于某些用例很有用。这通过将 PolicyVersion 转换添加到环境中来实现,然后由收集器在每次权重更新后递增。为此,可以在收集器构造函数中提供转换的状态版本或一个布尔值。

>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
>>> # init the updater
>>> collector.weight_updater.init(...)
>>> # the version is incremented after each weight update
>>> collector.update_policy_weights_(state_dict=...)
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
...     print(data["policy_version"])

vLLMUpdater([master_address, master_port, ...])

将权重发送到 vLLM 工作器的类。

LLMCollector(env, *[, policy, ...])

SyncDataCollector 的简化版本,用于 LLM 推理。

RayLLMCollector(env, *[, policy, ...])

LLM 收集器的轻量级 Ray 实现,可以远程扩展和采样。

环境

环境层负责数据加载、工具执行、奖励计算和格式化。当使用 TorchRL 微调 LLM 时,环境是推理管道的关键组成部分,与策略和收集器一起。

ChatEnv

ChatEnv 是 LLM 环境的空白画布——它是一个基本工具,旨在通过添加特定功能的转换来扩展。基础 ChatEnv 使用 History 格式提供管理对话状态的基本结构,但它有意保持最小化以提供最大的灵活性。

核心功能

ChatEnv 在三种主要模式下运行:- History 模式:使用 History 对象进行对话管理- Text 模式:使用简单的文本字符串进行输入/输出- Tokens 模式:使用 token 化数据进行输入/输出

环境通过以下方式维护对话状态:- 重置:用可选的系统 prompt 初始化新对话- :接收 LLM 的响应并更新对话历史,准备下一个 prompt

基于转换的架构

转换是扩展 ChatEnv 以获得特定功能的​​主要方式:

与 LLM Wrapper 的集成

ChatEnv 设计为与 TransformersWrappervLLMWrapper 无缝协同工作。环境负责管理对话状态,而 Wrapper 负责实际的 LLM 推理,从而实现清晰的关注点分离。

每次调用 step 时,环境都会:

  • 获取 LLM 的输出,特别是 full 字段,其中包含迄今为止的整个对话,包括新的响应(例如,history.fulltext.fulltokens.full)。

  • 将此 full 字段设置为下一个 LLM 步骤的新 prompt(例如,td[“next”, “history”].prompttd[“next”, “text”].prompttd[“next”, “tokens”].prompt)。

  • 可选地,应用转换以插入新的用户消息、工具调用或其他对话修改,以在下一个 LLM 步骤之前优化提示。

此机制实现了无缝的多轮交互,并支持工具使用和奖励塑形等复杂工作流。

任务特定环境

我们提供了一些特定任务的环境,例如用于 GSM8K 数据集的 GSM8KEnv、用于 IFEval 数据集的 IFEvalEnv 以及用于 MLGym 集成的 MLGymEnv

这些环境封装了一个 ChatEnv,并在 TransformedEnv 类中添加了 DataLoadingPrimer 转换(以及可选的奖励解析转换)。

ChatEnv(*args, **kwargs)

一个基于聊天的 LLM 环境,设计为一个对话和 RL 的空白画布。

DatasetChatEnv(*args, **kwargs)

用于从数据集中提取查询的聊天环境的基类。

GSM8KEnv(*args, **kwargs)

GSM8K 数据集环境。

make_gsm8k_env([dataset, num_envs, repeats, ...])

基于 LLMEnv 的 GSM8K 环境的构建器。

GSM8KPrepareQuestion([in_keys, out_keys])

在 LLMEnv 中使用 GSM8k 时用于准备提示的转换。

IFEvalEnv(*args, **kwargs)

基于 IFEval 数据集的聊天环境。

IfEvalScorer(*[, instruction_ids_key, ...])

IF-Eval 任务的评分器。

IFEvalScoreData(prompt_level_strict_acc, ...)

LLMEnv(*args, **kwargs)

面向语言模型的文本生成环境。

LLMHashingEnv(*args, **kwargs)

一个使用哈希模块来识别唯一观测值的文本生成环境。

make_mlgym(*[, task, tasks, tokenizer, ...])

将 MLGymEnv 封装到 TorchRL 环境中。

MLGymWrapper(*args, **kwargs)

MLGym 环境的精简封装。

GSM8KRewardParser(tokenizer[, in_keys, ...])

GSM8KEnv 或 make_gsm8k_env 的奖励解析器。

变换 (Transforms)

转换用于在数据传递给 LLM 之前修改数据。工具通常实现为转换,并附加到 ChatEnv 等基类环境。

工具转换的一个例子是 PythonInterpreter 转换,它用于在 LLM 的上下文中执行 Python 代码。

>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
...     tokenizer=tokenizer,
...     system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
...     user_role="user",
...     system_role="system",
...     batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
...     text="Let's write a Python function that returns the square of a number.",
...     batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
...     return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(reset_data)
>>> # The history is a stack of chat messages.
>>> #  The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
 'You are an assistant that can execute Python code. Decorate your code with '
 '```python``` tags.<|im_end|>\n'
 '<|im_start|>user\n'
 "Let's write a Python function that returns the square of a "
 'number.<|im_end|>\n'
 '<|im_start|>assistant\n'
 'Here is a block of code to be executed in python:\n'
 '```python\n'
 'def square(x):\n'
 '    return x * x\n'
 "print('testing the square function with input 2:', square(2))\n"
 '```<|im_end|>\n'
 '<|im_start|>user\n'
 '<tool_response>\n'
 'Code block 1 executed successfully:\n'
 'testing the square function with input 2: 4\n'
 '\n'
 '</tool_response><|im_end|>\n'
 '<|im_start|>assistant\n']

类似地,从数据集中加载数据的环境只是 ChatEnv 的特殊实例,并增加了 DataLoadingPrimer 转换(以及一些专门的奖励解析转换)。

设计奖励转换

在为 LLM 环境设计奖励转换时,必须解决几个关键考虑因素,以确保与训练流水线的正确集成。GSM8KRewardParserIfEvalScorer 的示例为奖励转换设计提供了极好的模板。

奖励形状要求

奖励张量必须具有与 logits 相同的维度数,通常比环境批次大小多两个维度。

  • 稀疏奖励:形状为 (*bsz, 1, 1) - 每个序列一个奖励

  • 密集奖励:形状为 (*bsz, num_tokens, 1) - 每个 token 的奖励

此形状要求确保与损失计算流水线的兼容性。例如,在 GSM8K 奖励解析器中

# Rewards need to have shape broadcastable to [batch x tokens x 1]
tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))

完成状态管理

正确管理完成状态以防止无休止生成至关重要。常见的策略包括:

  1. 基于完成的终止:当响应完成时设置完成(例如,History.complete=True

  2. 基于内容的终止:检测到特定内容时设置完成(例如,<answer> 块)

  3. 基于步数的终止:使用 StepCounter 进行预设的步数限制。

IFEvalScorer 的示例

if self.set_done_if_answer and bool(answer_blocks):
    next_tensordict.set("done", torch.ones(...))
    next_tensordict.set("terminated", torch.ones(...))

输入模式处理

奖励转换必须正确处理不同的输入模式。

  • 历史模式:从 ("history", "full")("history", "response") 中提取文本。

  • 文本模式:直接使用来自 ("text", "full")("text", "response") 的文本。

  • Token 模式:从 ("tokens", "full")("tokens", "response") 中解码 token。

GSM8K 奖励解析器展示了这种模式。

if input_mode == "history":
    responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
    if hasattr(responses, "content"):
        text_completion = responses.content
elif input_mode == "text":
    text_completion = responses
elif input_mode == "tokens":
    text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())

规范管理

准确指定奖励和观察规范对于正确的环境初始化至关重要。GSM8K 和 IFEval 都提供了很好的示例。

def transform_reward_spec(self, reward_spec: Composite) -> Composite:
    shape = reward_spec.shape + (1, 1)
    reward_spec.update(
        Composite(
            reward_answer=Unbounded(shape),
            reward_think=Unbounded(shape),
            reward_right=Unbounded(shape),
            reward_contained=Unbounded(shape),
            reward=Unbounded(shape),
            success=Unbounded(shape, dtype=torch.bool),
        )
    )
    return reward_spec

批处理注意事项

为了高效处理,请适当处理批处理数据。

  1. 展平批次维度:使用 tensordict.view(-1) 进行处理。

  2. 重塑结果:处理后恢复原始批次结构。

  3. 处理可变长度序列:使用适当的填充和掩码。

奖励聚合策略

考虑不同的奖励聚合方法。

  1. 简单聚合:对多个奖励组件求和或求平均。

  2. 加权聚合:对不同组件应用不同的权重。

  3. 条件奖励:根据特定条件或阈值确定奖励。

IFEvalScorer 展示了一种复杂的聚合策略。

def default_reward_aggregator(self, score: IFEvalScoreData, ...):
    # Format score (max 1.0)
    format_score = (format_components * weights).sum(dim=-1, keepdim=True)

    # Structure score (max 1.0)
    structure_score = think_score + answer_score

    # Completion bonus (max 0.2)
    completion_bonus = float(complete) * 0.2

    return format_score + structure_score + completion_bonus

回放缓冲区中的后处理

奖励也可以通过将转换附加到回放缓冲区来事后计算。但是,完成状态捕获必须保留在环境转换中,因为它需要在数据收集期间实时进行。

错误处理和鲁棒性

实现鲁棒的错误处理以应对解析失败。

try:
    cot, potential_answer = self.extract_tags(compl)
except ET.ParseError:
    cot, potential_answer = ("", "")

性能考虑

  1. 避免重复计算:尽可能缓存解析结果。

  2. 使用高效的文本处理:根据需要利用正则表达式或 XML 解析。

  3. 最小化内存分配:重用张量并避免不必要的复制。

通过遵循这些设计原则,可以将奖励转换有效地集成到 LLM 训练流水线中,同时保持性能和可靠性。

AddThinkingPrompt(cond[, prompt, ...])

一个添加思考提示以鼓励 LLM 重新考虑其响应的转换。

BrowserTransform([allowed_domains, ...])

一个启用网页浏览功能的转换。

DataLoadingPrimer(dataloader, *[, primers, ...])

一个从数据加载器加载数据并使用 stack_method 将其转换为 tensordict 的 primer。

KLComputation([gen_log_probs_full_key, ...])

一个用于计算两个 log-prob 张量之间的 KL 散度并可选地将其添加到奖励的转换。

KLRewardTransform(ref_model, *[, coef, ...])

用于计算基于 KL 散度的奖励的旧版转换。

MCPToolTransform(tools, tool_schemas[, ...])

一个响应 LLM 操作执行 MCP 风格工具的转换。

PolicyVersion(version_type, ] =)

一个跟踪策略版本的转换。

PythonInterpreter([tokenizer, tool_name, ...])

一个在 LLM 响应中执行 Python 代码的转换。

RetrieveKL([gen_model, ref_model, ...])

一个用于检索两个模型 log-probabilities 之间 KL 散度的转换。

RetrieveLogProb(model, *[, ...])

一个从模型中检索 log-probabilities 以进行 KL 散度计算的转换。

TemplateTransform(tokenizer[, chat_template])

一个在正向传播期间将聊天模板应用于输入字符串,并在反向传播期间将字符串解析到模板的转换。

Tokenizer([in_keys, out_keys, in_keys_inv, ...])

对指定输入应用分词操作。

as_nested_tensor(list_of_tensordicts)

将 tensordicts 列表堆叠成一个包含嵌套张量的单个 tensordict。

as_padded_tensor(list_of_tensordicts[, dim, ...])

将 tensordicts 列表堆叠成一个包含填充张量的单个 tensordict。

目标

LLM 训练后需要专门的损失函数,这些函数已针对语言模型的独特特征进行了调整。

GRPO

GRPOLoss 类是 PPOLoss 类的精简封装,它实现了 LLM 特有的功能。

GRPOLoss(*args, **kwargs)

GRPO 损失。

GRPOLossOutput(loss_objective, ...[, ...])

MCAdvantage(grpo_size[, prompt_key, ...])

蒙特卡洛优势计算引擎。

SFT

SFTLoss(*args, **kwargs)

监督微调损失。

SFTLossOutput(loss_sft[, loss_kl_to_ref, ...])

TopKRewardSelector(total_dialog_turns, topk_size)

一个回放缓冲区转换,用于为每个提示选择 top-k 奖励。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源