LLM 接口¶
TorchRL 为 LLM 的训练后和微调提供了一个全面的框架。LLM API围绕五个核心概念构建,它们共同创建一个完整的语言模型强化学习管道。
数据表示 (数据结构):处理对话、文本解析和 LLM 输出类的基础。这包括用于管理对话上下文的
History
类,以及用于 token、对数概率和文本的结构化输出类。LLM 包装器 API (模块):用于不同 LLM 后端的统一接口,包括用于 Hugging Face 模型的
TransformersWrapper
和用于 vLLM 推理的vLLMWrapper
。这些包装器在不同后端之间提供了一致的输入/输出格式,以及用于损失计算、数据存储、评分、权重同步等的集成接口。环境 (环境):管理数据加载、工具执行、奖励计算和格式化的编排层。这包括用于对话管理的
ChatEnv
、数据集环境以及用于工具集成的各种转换。目标 (目标):用于 LLM 训练的专用损失函数,包括用于组相对策略优化 (Group Relative Policy Optimization) 的
GRPOLoss
和用于监督微调 (supervised fine-tuning) 的SFTLoss
。收集器 (收集器):收集器用于从环境中收集数据并将其存储为可用于训练的格式。这包括用于从环境中收集数据的
LLMCollector
和用于使用 Ray 在分布式环境中收集数据的RayLLMCollector
。
这些组件协同工作,创建一个完整的管道:环境加载和格式化数据,LLM 包装器处理推理,数据结构维护对话上下文,目标计算训练损失。模块化设计允许您根据特定用例混合搭配组件。
LLM API 的完整示例可在 sota-implementations/grpo/ 目录中找到。训练编排涉及三个主要组件:
数据收集器:持有对环境和推理模型或引擎的引用。它收集数据,将其放入缓冲区,并处理权重更新。
回放缓冲区:存储收集到的数据并执行任何预处理或后处理步骤。这些可能包括:- 使用基于蒙特卡洛的方法进行优势估计(使用
MCAdvantage
转换);- 对输出进行评分;- 日志记录等。训练器:处理训练循环,包括优化步骤、序列化、日志记录和权重更新初始化。
警告
LLM API 仍在开发中,未来可能会发生变化。欢迎提供反馈、报告问题和提交拉取请求!
数据结构¶
数据表示层提供了以结构化方式处理对话和 LLM 输出的基础。
History 类¶
History
类是 transformers 中通常找到的聊天格式的 TensorClass 版本(请参阅 Hugging Face 聊天文档)。它提供了一个全面的 API,用于管理具有以下功能的对话数据:
文本解析和格式化:使用
from_text()
和apply_chat_template()
在文本和结构化对话格式之间进行转换动态对话构建:使用
append()
和extend()
方法追加和扩展对话多模型支持:自动检测各种模型系列的模板(Qwen、DialoGPT、Falcon、DeepSeek 等)
助手 token 掩码:识别为强化学习应用生成的 token
工具调用支持:处理对话中的函数调用和工具响应
批量操作:用于同时处理多个对话的高效张量操作。
|
|
|
支持的模型系列¶
我们目前支持以下模型系列进行字符串到 History 的解析或助手 token 掩码:
Qwen 系列(例如,Qwen/Qwen2.5-0.5B):具有完整工具调用支持的自定义模板
DialoGPT 系列(例如,microsoft/DialoGPT-medium):用于对话格式的自定义模板
Falcon 系列(例如,tiiuae/falcon-7b-instruct):用于指令格式的自定义模板
DeepSeek 系列(例如,deepseek-ai/deepseek-coder-6.7b-base):具有原生格式的自定义模板
支持其他模型,但您需要为它们提供自定义模板。LLAMA、Mistral、OPT、GPT、MPT、BLOOM、Pythia、Phi 等将使用默认的 chatml_format 模板。
用法¶
>>> from torchrl.data.llm.chat import History
>>> from transformers import AutoTokenizer
>>>
>>> # Create a conversation history
>>> history = History.from_chats([[
... {"role": "user", "content": "Hello"},
... {"role": "assistant", "content": "Hi there!"},
... {"role": "user", "content": "How are you?"},
... {"role": "assistant", "content": "I'm doing well, thanks!"}
... ]])
>>>
>>> # Load any supported tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
>>>
>>> # Apply chat template with assistant token masking
>>> result = history.apply_chat_template(
... chat_template_name="qwen",
... add_generation_prompt=False,
... return_dict=True,
... return_assistant_tokens_mask=True,
... )
>>>
>>> # The result contains an assistant_masks tensor
>>> assistant_masks = result["assistant_masks"]
>>> print(f"Assistant tokens: {assistant_masks.sum().item()}")
添加自定义模板¶
您可以使用 torchrl.data.llm.chat.add_chat_template()
函数为新的模型系列添加自定义聊天模板。
用法示例¶
添加 Llama 模板¶
>>> from torchrl.data.llm.chat import add_chat_template, History
>>> from transformers import AutoTokenizer
>>>
>>> # Define the Llama chat template
>>> llama_template = '''
... {% for message in messages %}
... {%- if message['role'] == 'user' %}
... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
... {%- elif message['role'] == 'assistant' %}
... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
... {%- endif %}
... {% endfor %}
... {%- if add_generation_prompt %}
... {% generation %}{{ ' ' }}{% endgeneration %}
... {%- endif %}
... '''
>>>
>>> # Define the inverse parser for Llama format
>>> def parse_llama_text(text: str) -> History:
... import re
... pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
... matches = re.findall(pattern, text, re.DOTALL)
... messages = []
... for user_content, assistant_content in matches:
... messages.append(History(role="user", content=user_content.strip()))
... messages.append(History(role="assistant", content=assistant_content.strip()))
... return lazy_stack(messages)
>>>
>>> # Add the template with auto-detection
>>> add_chat_template(
... template_name="llama",
... template=llama_template,
... inverse_parser=parse_llama_text,
... model_family_keywords=["llama", "meta-llama"]
... )
>>>
>>> # Now you can use it with auto-detection
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> history = History.from_chats([[
... {"role": "user", "content": "Hello"},
... {"role": "assistant", "content": "Hi there!"}
... ]])
>>>
>>> # Auto-detection will use the llama template
>>> result = history.apply_chat_template(
... tokenizer=tokenizer,
... add_generation_prompt=False,
... return_dict=True,
... return_assistant_tokens_mask=True,
... )
测试您的自定义模板¶
添加自定义模板时,您应该对其进行测试以确保其正常工作。以下是推荐的测试:
助手 token 掩码测试¶
测试您的模板是否支持助手 token 掩码
import pytest
from torchrl.data.llm.chat import History, add_chat_template
from transformers import AutoTokenizer
def test_my_model_assistant_masking():
"""Test that your model supports assistant token masking."""
# Add your template first
add_chat_template(
template_name="my_model",
template="your_template_here",
model_family_keywords=["my_model"]
)
tokenizer = AutoTokenizer.from_pretrained("your/model/name")
history = History.from_chats([[
{'role': 'user', 'content': 'Hello'},
{'role': 'assistant', 'content': 'Hi there!'}
]])
result = history.apply_chat_template(
tokenizer=tokenizer,
chat_template_name="my_model",
add_generation_prompt=False,
return_dict=True,
return_assistant_tokens_mask=True,
)
# Verify assistant mask is present
assert 'assistant_masks' in result
assert result['assistant_masks'].shape[0] == 1, "Should have batch dimension of 1"
assert result['assistant_masks'].shape[1] > 0, "Should have sequence length > 0"
# Verify some assistant tokens are masked
assistant_token_count = result['assistant_masks'].sum().item()
assert assistant_token_count > 0, "Should have assistant tokens masked"
print(f"✓ {assistant_token_count} assistant tokens masked")
模板等效性测试¶
测试您的自定义模板是否产生与模型默认模板相同的输出(掩码除外)
def test_my_model_template_equivalence():
"""Test that your template matches the model's default template."""
tokenizer = AutoTokenizer.from_pretrained("your/model/name")
history = History.from_chats([[
{'role': 'user', 'content': 'Hello'},
{'role': 'assistant', 'content': 'Hi there!'},
{'role': 'user', 'content': 'How are you?'},
{'role': 'assistant', 'content': 'I\'m good, thanks!'},
]])
# Get output with model's default template
try:
default_out = history.apply_chat_template(
tokenizer=tokenizer,
add_generation_prompt=False,
chat_template=tokenizer.chat_template,
tokenize=False,
)
except Exception as e:
default_out = None
print(f"[WARN] Could not get default template: {e}")
# Get output with your custom template
custom_out = history.apply_chat_template(
tokenizer=tokenizer,
add_generation_prompt=False,
chat_template_name="my_model",
tokenize=False,
)
if default_out is not None:
# Normalize whitespace for comparison
import re
def norm(s):
return re.sub(r"\s+", " ", s.strip())
assert norm(default_out) == norm(custom_out), (
f"Custom template does not match default!\n"
f"Default: {default_out}\nCustom: {custom_out}"
)
print("✓ Template equivalence verified")
else:
print("[INFO] Skipped equivalence check (no default template available)")
逆向解析测试¶
如果您提供了逆向解析器,请测试其是否正常工作
def test_my_model_inverse_parsing():
"""Test that your inverse parser works correctly."""
history = History.from_chats([[
{'role': 'user', 'content': 'Hello'},
{'role': 'assistant', 'content': 'Hi there!'}
]])
# Format using your template
formatted = history.apply_chat_template(
tokenizer=tokenizer,
chat_template_name="my_model",
add_generation_prompt=False,
tokenize=False,
)
# Parse back using your inverse parser
parsed = History.from_text(formatted, chat_template_name="my_model")
# Verify the parsing worked
assert parsed.role == history.role
assert parsed.content == history.content
print("✓ Inverse parsing verified")
LLM 包装器 API¶
LLM 包装器 API 为不同的 LLM 后端提供了统一的接口,确保跨训练和推理管道的一致输入/输出格式。主要的包装器是用于 Hugging Face 模型的 TransformersWrapper
和用于 vLLM 推理的 vLLMWrapper
。
数据结构类
包装器使用结构化的 TensorClass
对象来表示 LLM 数据的不同方面。
:class:`~torchrl.modules.llm.policies.Text`:包含具有 prompt、response 和 full 字段的文本数据
:class:`~torchrl.modules.llm.policies.ChatHistory`:包含具有 prompt、response 和 full 字段的
History
对象:class:`~torchrl.modules.llm.policies.Tokens`:包含具有 prompt、response 和 full 字段的 token 化数据
:class:`~torchrl.modules.llm.policies.LogProbs`:包含具有 prompt、response 和 full 字段的对数概率
:class:`~torchrl.modules.llm.policies.Masks`:包含注意力掩码和助手掩码
API 流程
包装器有两种不同的模式运行:
生成模式 (`generate=True`):- 输入:从 prompt 字段读取(例如,history.prompt、text.prompt、tokens.prompt)- 输出:写入 response 和 full 字段
response:仅包含新生成的内容
full:包含完整序列(prompt + response)
对数概率模式 (`generate=False`):- 输入:从 full 字段读取(例如,history.full、text.full、tokens.full)- 输出:将对数概率写入相应的 full 字段
LLM-环境交互循环

LLM-环境交互:LLM 生成响应,环境更新对话,转换可以注入新消息或工具。¶
在典型的 RL 或工具增强设置中,LLM 和环境在一个循环中交互:
LLM 生成:LLM 包装器接收一个 prompt(当前对话历史),生成一个 response,并输出一个 full 字段
包含 prompt 和 response 的连接。
环境步:环境接收 full 字段并将其作为下一个 prompt 提供给 LLM。这确保了对话
上下文在每个回合中都会增长。有关更多详细信息,请参阅 ref_env_llm_step。
转换:在下一个 LLM 步骤之前,转换可以修改对话——例如,通过插入新的用户消息、工具调用
或奖励注释。
重复:此过程根据需要重复多个回合,从而支持多回合对话、工具使用和 RL 训练。
此设计允许在每个步骤中灵活地增强对话,支持高级 RL 和工具使用场景。
典型的伪代码循环
# Get the first prompt out of an initial query
obs = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size, device=env.device))
while not done:
# LLM generates a response given the current prompt
llm_output = llm(obs)
# Environment steps: creates a ("next", "history") field with the new prompt (from the previous `"full"` field)
obs = env.step(llm_output)
与 History 集成
当使用 input_mode=”history” 时,包装器与 History
类无缝集成。
输入:接收一个
ChatHistory
对象,其中 prompt 字段包含一个 History。生成:应用聊天模板将 History 转换为 token,生成响应,然后将完整文本解析回 History 对象。
输出:返回一个 ChatHistory,其中:- prompt:原始对话历史- response:包含助手响应的新 History 对象- full:完整对话历史,并附加了新的响应
此设计允许自然的对话流程,其中每个生成步骤都会扩展对话历史,使其成为多回合对话系统的理想选择。
Prompt vs. Response 和 padding¶
LLM 输出结构:Token、LogProbs 和 Masks 的填充(padded)与稀疏(sparse)表示。¶
上图说明了 TorchRL 的 LLM API 中使用的主要输出类的结构:
Tokens(以及扩展的 LogProbs):- 填充格式:批次中的所有序列都填充到相同的长度(使用特殊填充 token),使其适用于张量操作。Prompt 和 response 连接形成 tokens.full,掩码指示有效位置与填充位置的关系。- 稀疏格式:每个序列保留其原始长度(无填充),表示为张量列表。这对于可变长度数据更节省内存。
Masks:显示了两个主要掩码:- mask.attention_mask_all 标记有效(非填充)token。- mask.assistant_mask_all 标记由助手生成的 token(对于 RLHF 和 SFT 训练很有用)。
Text:未详细显示,因为它只是 prompt、response 或完整序列的解码字符串表示。
此格式确保所有 LLM 输出(Tokens、LogProbs、Masks、Text)都保持一致且易于操作,无论您使用填充还是稀疏批处理。
通常,我们建议使用未填充的数据,因为它更节省内存且更易于操作。例如,当从缓冲区收集多个填充元素时,很难清楚地了解如何重新填充它们以将它们组合成一个连贯的批次。使用未填充的数据更直接。
模块¶
LLM 包装器 API 为不同的 LLM 后端提供了统一的接口,确保跨训练和推理管道的一致输入/输出格式。
包装器¶
这些原语的主要目标是:
统一训练和推理管道中的输入/输出数据格式
统一后端之间的输入/输出数据格式(以便可以在不同损失和收集器中使用不同的后端)
提供适当的工具来在典型的 RL 设置中构建这些对象(资源分配、异步执行、权重更新等)
|
LLM 包装器基类。 |
|
Hugging Face Transformers 模型的包装器类,为文本生成和对数概率计算提供了一致的接口。 |
|
vLLM 模型的包装器类,为文本生成和对数概率计算提供了一致的接口。 |
|
|
|
|
|
|
|
|
|
工具¶
|
一个围绕 vllm.LLM 的薄包装器,用于控制其设备放置。 |
|
创建支持张量并行的 vLLM 推理引擎。 |
|
初始化一个无状态进程组以进行分布式通信。 |
|
Ray 的 vLLM 工作器。 |
收集器¶
TorchRL 提供专门的收集器类(LLMCollector
和 RayLLMCollector
),它们针对 LLM 用例进行了定制。我们还为某些推理引擎提供专用更新程序。
有关收集器 API 的更多详细信息,请参阅 ref_collectors。简而言之,收集器的理念是将管道的推理部分隔离到一个专用的类中。收集器通常以策略和环境作为输入,并在运行一个和另一个之间进行交替。在“经典”设置中,策略类似于正在训练的策略(可能包含额外的探索)。在 LLM 微调的上下文中,策略通常是专门的推理引擎,例如 vLLM 服务器。收集器由以下参数和功能定义:
同步/异步:收集器是应以同步还是异步模式运行。在同步模式下,收集器将在推理步骤与优化/训练步骤之间交替运行。在异步模式下,收集器将在推理步骤与优化/训练步骤并行运行。可以向收集器传递回放缓冲区,以便收集器可以直接写入其中。在其他情况下,可以迭代收集器来收集数据。
步数:收集器会根据一定的步数预算构建,同时在收集期间生成的每个批次中包含一定数量的步数。
权重更新程序:权重更新程序是更新策略权重的类。将权重更新隔离到一个专用的类中,可以根据策略规范轻松实现不同的权重更新策略。
策略版本跟踪¶
LLM 收集器还允许跟踪策略版本,这对于某些用例很有用。这通过将 PolicyVersion
转换添加到环境中来实现,然后由收集器在每次权重更新后递增。为此,可以在收集器构造函数中提供转换的状态版本或一个布尔值。
>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
>>> # init the updater
>>> collector.weight_updater.init(...)
>>> # the version is incremented after each weight update
>>> collector.update_policy_weights_(state_dict=...)
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
... print(data["policy_version"])
|
将权重发送到 vLLM 工作器的类。 |
|
SyncDataCollector 的简化版本,用于 LLM 推理。 |
|
LLM 收集器的轻量级 Ray 实现,可以远程扩展和采样。 |
环境¶
环境层负责数据加载、工具执行、奖励计算和格式化。当使用 TorchRL 微调 LLM 时,环境是推理管道的关键组成部分,与策略和收集器一起。
ChatEnv
ChatEnv
是 LLM 环境的空白画布——它是一个基本工具,旨在通过添加特定功能的转换来扩展。基础 ChatEnv 使用 History
格式提供管理对话状态的基本结构,但它有意保持最小化以提供最大的灵活性。
核心功能¶
ChatEnv 在三种主要模式下运行:- History 模式:使用 History
对象进行对话管理- Text 模式:使用简单的文本字符串进行输入/输出- Tokens 模式:使用 token 化数据进行输入/输出
环境通过以下方式维护对话状态:- 重置:用可选的系统 prompt 初始化新对话- 步:接收 LLM 的响应并更新对话历史,准备下一个 prompt
基于转换的架构¶
转换是扩展 ChatEnv 以获得特定功能的主要方式:
奖励计算:用于 KL 散度奖励的
KLRewardTransform
工具执行:用于 Python 代码执行的
PythonInterpreter
,用于通用工具调用的MCPToolTransform
。数据加载:用于从数据集中加载 prompt 的
DataLoadingPrimer
思维 prompt:用于链式思考推理的
AddThinkingPrompt
策略跟踪:
PolicyVersion
用于版本控制。步数计数:使用
StepCounter
内置步数跟踪和重置管理。
与 LLM Wrapper 的集成¶
ChatEnv
设计为与 TransformersWrapper
和 vLLMWrapper
无缝协同工作。环境负责管理对话状态,而 Wrapper 负责实际的 LLM 推理,从而实现清晰的关注点分离。
每次调用 step 时,环境都会:
获取 LLM 的输出,特别是 full 字段,其中包含迄今为止的整个对话,包括新的响应(例如,history.full、text.full、tokens.full)。
将此 full 字段设置为下一个 LLM 步骤的新 prompt(例如,td[“next”, “history”].prompt、td[“next”, “text”].prompt、td[“next”, “tokens”].prompt)。
可选地,应用转换以插入新的用户消息、工具调用或其他对话修改,以在下一个 LLM 步骤之前优化提示。
此机制实现了无缝的多轮交互,并支持工具使用和奖励塑形等复杂工作流。
任务特定环境¶
我们提供了一些特定任务的环境,例如用于 GSM8K 数据集的 GSM8KEnv
、用于 IFEval 数据集的 IFEvalEnv
以及用于 MLGym 集成的 MLGymEnv
。
这些环境封装了一个 ChatEnv
,并在 TransformedEnv
类中添加了 DataLoadingPrimer
转换(以及可选的奖励解析转换)。
|
一个基于聊天的 LLM 环境,设计为一个对话和 RL 的空白画布。 |
|
用于从数据集中提取查询的聊天环境的基类。 |
|
GSM8K 数据集环境。 |
|
基于 LLMEnv 的 GSM8K 环境的构建器。 |
|
在 LLMEnv 中使用 GSM8k 时用于准备提示的转换。 |
|
基于 IFEval 数据集的聊天环境。 |
|
IF-Eval 任务的评分器。 |
|
|
|
面向语言模型的文本生成环境。 |
|
一个使用哈希模块来识别唯一观测值的文本生成环境。 |
|
将 MLGymEnv 封装到 TorchRL 环境中。 |
|
MLGym 环境的精简封装。 |
|
GSM8KEnv 或 make_gsm8k_env 的奖励解析器。 |
变换 (Transforms)¶
转换用于在数据传递给 LLM 之前修改数据。工具通常实现为转换,并附加到 ChatEnv
等基类环境。
工具转换的一个例子是 PythonInterpreter
转换,它用于在 LLM 的上下文中执行 Python 代码。
>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
... tokenizer=tokenizer,
... system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
... user_role="user",
... system_role="system",
... batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
... text="Let's write a Python function that returns the square of a number.",
... batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
... return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(reset_data)
>>> # The history is a stack of chat messages.
>>> # The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
'You are an assistant that can execute Python code. Decorate your code with '
'```python``` tags.<|im_end|>\n'
'<|im_start|>user\n'
"Let's write a Python function that returns the square of a "
'number.<|im_end|>\n'
'<|im_start|>assistant\n'
'Here is a block of code to be executed in python:\n'
'```python\n'
'def square(x):\n'
' return x * x\n'
"print('testing the square function with input 2:', square(2))\n"
'```<|im_end|>\n'
'<|im_start|>user\n'
'<tool_response>\n'
'Code block 1 executed successfully:\n'
'testing the square function with input 2: 4\n'
'\n'
'</tool_response><|im_end|>\n'
'<|im_start|>assistant\n']
类似地,从数据集中加载数据的环境只是 ChatEnv
的特殊实例,并增加了 DataLoadingPrimer
转换(以及一些专门的奖励解析转换)。
设计奖励转换¶
在为 LLM 环境设计奖励转换时,必须解决几个关键考虑因素,以确保与训练流水线的正确集成。GSM8KRewardParser
和 IfEvalScorer
的示例为奖励转换设计提供了极好的模板。
奖励形状要求
奖励张量必须具有与 logits 相同的维度数,通常比环境批次大小多两个维度。
稀疏奖励:形状为
(*bsz, 1, 1)
- 每个序列一个奖励密集奖励:形状为
(*bsz, num_tokens, 1)
- 每个 token 的奖励
此形状要求确保与损失计算流水线的兼容性。例如,在 GSM8K 奖励解析器中
# Rewards need to have shape broadcastable to [batch x tokens x 1]
tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
完成状态管理
正确管理完成状态以防止无休止生成至关重要。常见的策略包括:
基于完成的终止:当响应完成时设置完成(例如,
History.complete=True
)基于内容的终止:检测到特定内容时设置完成(例如,
<answer>
块)基于步数的终止:使用
StepCounter
进行预设的步数限制。
IFEvalScorer 的示例
if self.set_done_if_answer and bool(answer_blocks):
next_tensordict.set("done", torch.ones(...))
next_tensordict.set("terminated", torch.ones(...))
输入模式处理
奖励转换必须正确处理不同的输入模式。
历史模式:从
("history", "full")
或("history", "response")
中提取文本。文本模式:直接使用来自
("text", "full")
或("text", "response")
的文本。Token 模式:从
("tokens", "full")
或("tokens", "response")
中解码 token。
GSM8K 奖励解析器展示了这种模式。
if input_mode == "history":
responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
if hasattr(responses, "content"):
text_completion = responses.content
elif input_mode == "text":
text_completion = responses
elif input_mode == "tokens":
text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())
规范管理
准确指定奖励和观察规范对于正确的环境初始化至关重要。GSM8K 和 IFEval 都提供了很好的示例。
def transform_reward_spec(self, reward_spec: Composite) -> Composite:
shape = reward_spec.shape + (1, 1)
reward_spec.update(
Composite(
reward_answer=Unbounded(shape),
reward_think=Unbounded(shape),
reward_right=Unbounded(shape),
reward_contained=Unbounded(shape),
reward=Unbounded(shape),
success=Unbounded(shape, dtype=torch.bool),
)
)
return reward_spec
批处理注意事项
为了高效处理,请适当处理批处理数据。
展平批次维度:使用
tensordict.view(-1)
进行处理。重塑结果:处理后恢复原始批次结构。
处理可变长度序列:使用适当的填充和掩码。
奖励聚合策略
考虑不同的奖励聚合方法。
简单聚合:对多个奖励组件求和或求平均。
加权聚合:对不同组件应用不同的权重。
条件奖励:根据特定条件或阈值确定奖励。
IFEvalScorer 展示了一种复杂的聚合策略。
def default_reward_aggregator(self, score: IFEvalScoreData, ...):
# Format score (max 1.0)
format_score = (format_components * weights).sum(dim=-1, keepdim=True)
# Structure score (max 1.0)
structure_score = think_score + answer_score
# Completion bonus (max 0.2)
completion_bonus = float(complete) * 0.2
return format_score + structure_score + completion_bonus
回放缓冲区中的后处理
奖励也可以通过将转换附加到回放缓冲区来事后计算。但是,完成状态捕获必须保留在环境转换中,因为它需要在数据收集期间实时进行。
错误处理和鲁棒性
实现鲁棒的错误处理以应对解析失败。
try:
cot, potential_answer = self.extract_tags(compl)
except ET.ParseError:
cot, potential_answer = ("", "")
性能考虑
避免重复计算:尽可能缓存解析结果。
使用高效的文本处理:根据需要利用正则表达式或 XML 解析。
最小化内存分配:重用张量并避免不必要的复制。
通过遵循这些设计原则,可以将奖励转换有效地集成到 LLM 训练流水线中,同时保持性能和可靠性。
|
一个添加思考提示以鼓励 LLM 重新考虑其响应的转换。 |
|
一个启用网页浏览功能的转换。 |
|
一个从数据加载器加载数据并使用 |
|
一个用于计算两个 log-prob 张量之间的 KL 散度并可选地将其添加到奖励的转换。 |
|
用于计算基于 KL 散度的奖励的旧版转换。 |
|
一个响应 LLM 操作执行 MCP 风格工具的转换。 |
|
一个跟踪策略版本的转换。 |
|
一个在 LLM 响应中执行 Python 代码的转换。 |
|
一个用于检索两个模型 log-probabilities 之间 KL 散度的转换。 |
|
一个从模型中检索 log-probabilities 以进行 KL 散度计算的转换。 |
|
一个在正向传播期间将聊天模板应用于输入字符串,并在反向传播期间将字符串解析到模板的转换。 |
|
对指定输入应用分词操作。 |
|
将 tensordicts 列表堆叠成一个包含嵌套张量的单个 tensordict。 |
|
将 tensordicts 列表堆叠成一个包含填充张量的单个 tensordict。 |
目标¶
LLM 训练后需要专门的损失函数,这些函数已针对语言模型的独特特征进行了调整。
GRPO¶
GRPOLoss
类是 PPOLoss
类的精简封装,它实现了 LLM 特有的功能。
|
GRPO 损失。 |
|
|
|
蒙特卡洛优势计算引擎。 |
SFT¶
|
监督微调损失。 |
|
|
一个回放缓冲区转换,用于为每个提示选择 top-k 奖励。 |