LLMCollector¶
- class torchrl.collectors.llm.LLMCollector(env: EnvBase | Callable[[], EnvBase], *, policy: Callable[[TensorDictBase], TensorDictBase] | None = None, policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]] | None = None, dialog_turns_per_batch: int | None = None, yield_only_last_steps: bool | None = None, yield_completed_trajectories: bool | None = None, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, total_dialog_turns: int = - 1, async_envs: bool | None = None, replay_buffer: ReplayBuffer | None = None, reset_at_each_iter: bool = False, flatten_data: bool | None = None, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, queue: Any | None = None, track_policy_version: bool | PolicyVersion = False, verbose: bool = False)[源]¶
LLM 推理的 SyncDataCollector 的简化版本。
- 参数:
env (EnvBase 或 EnvBase 构造函数) – 用于数据收集的环境。
- 关键字参数:
policy (Callable[[TensorDictBase], TensorDictBase]) – 用于数据收集的策略。
policy_factory (Callable[[], Callable], optional) –
一个可调用对象,它返回一个策略实例。这与 policy 参数互斥。
注意
policy_factory 在策略无法序列化时非常有用。
dialog_turns_per_batch (int, optional) – 一个关键字参数,表示批次中的元素总数。除非 yield_completed_trajectories=True,否则始终需要此参数。
total_dialog_turns (int) – 一个关键字参数,表示收集器在其生命周期内返回的总步数。-1 表示永不结束(直到关闭)。默认为 -1。
yield_completed_trajectories (bool, optional) –
是生成具有给定步数的 rollout 批次(yield_completed_trajectories=False,默认)还是单个、完整的 trajectories(yield_completed_trajectories=True)。默认为 False,除非 yield_only_last_steps=True,此时它不能为 False。
警告
如果环境的 done 状态未正确设置,这可能导致收集器永远不产生任何数据。
yield_only_last_steps (bool, optional) –
是生成 trajectory 的每一步,还是只生成最后一步(done)。如果为 True,则一次生成(或写入缓冲区)一个 trajectory。
警告
如果环境的 done 状态未正确设置,这可能导致收集器永远不产生任何数据。
postproc (Callable, optional) – 后处理转换,例如
Transform
或MultiStep
实例。默认为None
。async_envs (bool, optional) – 如果为
True
,环境将异步运行。如果环境是AsyncEnvPool
实例,则默认为 True。replay_buffer (ReplayBuffer, optional) – 如果提供,收集器将不会生成 tensordicts,而是填充缓冲区。默认为
None
。reset_at_each_iter (bool, optional) – 如果为
True
,将在每次迭代时重置环境。flatten_data (bool, optional) – 如果为
True
,收集器将在返回前展平收集到的数据。实际上,这意味着如果使用批次大小为 (B,) 的环境并运行 T 步,则 flatten_data=True 将显示形状为 (B*T,) 的数据,而 flatten_data=False 将不显示形状为 (B, T) 的数据。如果提供了 replay_buffer,则默认为 True,否则默认为 False。weight_updater (WeightUpdaterBase 或 构造函数, optional) –
WeightUpdaterBase
或其子类的实例,负责在远程推理工作器上更新策略权重。在SyncDataCollector
中通常不使用此参数,因为它在单个进程环境中运行。如果更新器需要序列化,请考虑使用构造函数。track_policy_version (bool 或 PolicyVersion, optional) – 如果为
True
,收集器将跟踪策略的版本。这将由PolicyVersion
转换器处理,该转换器将被添加到环境中。或者,也可以传递一个PolicyVersion
实例,用于跟踪策略版本。默认为 False。verbose (bool, optional) – 如果为
True
,收集器将打印进度信息。默认为 False。
示例
>>> import vllm >>> from torchrl.modules import vLLMWrapper >>> from pytorch.rl.test.mocking_classes import DummyStrDataLoader >>> from torchrl.envs import LLMEnv >>> llm_model = vllm.LLM("gpt2") >>> tokenizer = llm_model.get_tokenizer() >>> tokenizer.pad_token = tokenizer.eos_token >>> policy = vLLMWrapper(llm_model) >>> dataloader = DummyStrDataLoader(1) >>> env = LLMEnv.from_dataloader( ... dataloader=dataloader, ... tokenizer=tokenizer, ... from_text=True, ... batch_size=1, ... group_repeats=True, ... ) >>> collector = LLMCollector( ... env=env, ... policy_factory=lambda: policy, ... dialog_turns_per_batch=env.batch_size[0], ... total_dialog_turns=3, ... ) >>> for i, data in enumerate(collector): ... if i == 2: ... print(data) ... break LazyStackedTensorDict( fields={ attention_mask: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False), collector: LazyStackedTensorDict( fields={ traj_ids: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([1, 1]), device=None, is_shared=False, stack_dim=1), done: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False), terminated: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False), text: NonTensorStack( [['plsgqejeyd']], batch_size=torch.Size([1, 1]), device=None), text_response: NonTensorStack( [['ec.n.n.n.tjbjz3perwhz']], batch_size=torch.Size([1, 1]), device=None), tokens: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False), tokens_response: Tensor(shape=torch.Size([1, 1, 16]), device=cpu, dtype=torch.int64, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([1, 1]), device=None, is_shared=False, stack_dim=1) >>> del collector
- classmethod as_remote(remote_config: dict[str, Any] | None = None)¶
创建一个远程 ray 类的实例。
- 参数:
cls (Python Class) – 要远程实例化的类。
remote_config (dict) – 为此类保留的 CPU 核心数量。
- 返回:
一个创建 ray 远程类实例的函数。
- async_shutdown(timeout: float | None = None, close_env: bool = True) None ¶
结束 ray.init() 在异步执行期间启动的进程。
- property dialog_turns_per_batch: int¶
别名 frames_per_batch。
- get_policy_version() str | int | None [源]¶
获取当前策略版本。
此方法用于支持 Ray actors 中的远程调用,因为属性不能通过 Ray 的 RPC 机制直接访问。
- 返回:
当前版本号(int)或 UUID(str),或 None(如果禁用了版本跟踪)。
- init_updater(*args, **kwargs)¶
使用自定义参数初始化权重更新器。
此方法将参数传递给权重更新器的 init 方法。如果未设置权重更新器,则此方法无效。
- 参数:
*args – 用于权重更新器初始化的位置参数
**kwargs – 用于权重更新器初始化的关键字参数
- iterator() Iterator[TensorDictBase] ¶
迭代 DataCollector。
Yields: 包含轨迹 (块) 的 TensorDictBase 对象
- load_state_dict(state_dict: OrderedDict, **kwargs) None ¶
在环境和策略上加载 state_dict。
- 参数:
state_dict (OrderedDict) – 包含 “policy_state_dict” 和
"env_state_dict"
字段的有序字典。
- pause()¶
上下文管理器,如果收集器正在自由运行,则暂停收集器。
- property policy_version: str | int | None¶
当前策略版本。
- reset(index=None, **kwargs) None ¶
将环境重置到新的初始状态。
- property rollout: Callable[[], TensorDictBase]¶
使用提供的策略在环境中计算 rollout。
- 返回:
包含已计算 rollout 的 TensorDictBase。
- set_seed(seed: int, static_seed: bool = False) int ¶
设置 DataCollector 中存储的环境的种子。
- 参数:
seed (int) – 用于环境的种子整数。
static_seed (bool, optional) – 如果
True
,种子不会递增。默认为 False
- 返回:
输出种子。当 DataCollector 包含多个环境时,这很有用,因为种子会为每个环境递增。结果种子是最后一个环境的种子。
示例
>>> from torchrl.envs import ParallelEnv >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule >>> from torch import nn >>> env_fn = lambda: GymEnv("Pendulum-v1") >>> env_fn_parallel = ParallelEnv(6, env_fn) >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) >>> out_seed = collector.set_seed(1) # out_seed = 6
- shutdown(timeout: float | None = None, close_env: bool = True) None ¶
关闭所有工作器和/或本地环境。
- 参数:
timeout (float, optional) – 工作器之间关闭管道的超时时间。对此类无效。
close_env (bool, optional) – 是否关闭环境。默认为 True。
- start()¶
在单独的线程中启动收集器以进行异步数据收集。
收集到的数据存储在提供的回放缓冲区中。当您希望将数据收集与训练分离时,此方法非常有用,允许您的训练循环独立于数据收集过程运行。
- 抛出:
RuntimeError – 如果在收集器初始化期间未定义回放缓冲区。
示例
>>> import time >>> from functools import partial >>> >>> import tqdm >>> >>> from torchrl.collectors import SyncDataCollector, RandomPolicy >>> from torchrl.data import LazyTensorStorage, ReplayBuffer >>> from torchrl.envs import GymEnv, set_gym_backend >>> import ale_py >>> >>> # Set the gym backend to gymnasium >>> set_gym_backend("gymnasium").set() >>> >>> if __name__ == "__main__": ... # Create a random policy for the Pong environment ... env = GymEnv("ALE/Pong-v5") ... policy = RandomPolicy(env.action_spec) ... ... # Initialize a shared replay buffer ... rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True) ... ... # Create a synchronous data collector ... collector = SyncDataCollector( ... env, ... policy=policy, ... replay_buffer=rb, ... frames_per_batch=256, ... total_frames=-1, ... ) ... ... # Progress bar to track the number of collected frames ... pbar = tqdm.tqdm(total=100_000) ... ... # Start the collector asynchronously ... collector.start() ... ... # Track the write count of the replay buffer ... prec_wc = 0 ... while True: ... wc = rb.write_count ... c = wc - prec_wc ... prec_wc = wc ... ... # Update the progress bar ... pbar.update(c) ... pbar.set_description(f"Write Count: {rb.write_count}") ... ... # Check the write count every 0.5 seconds ... time.sleep(0.5) ... ... # Stop when the desired number of frames is reached ... if rb.write_count . 100_000: ... break ... ... # Shut down the collector ... collector.async_shutdown()
- state_dict() OrderedDict ¶
返回数据收集器的本地 state_dict(环境和策略)。
- 返回:
包含
"policy_state_dict"
和 “env_state_dict” 字段的有序字典。
- update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) None ¶
更新数据收集器的策略权重,支持本地和远程执行上下文。
此方法确保数据收集器使用的策略权重与最新的训练权重同步。它支持本地和远程权重更新,具体取决于数据收集器的配置。本地(下载)更新在远程(上传)更新之前执行,以便可以将权重从服务器传输到子工作器。
- 参数:
policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None) – 要更新的权重。可以是: - TensorDictModuleBase:将提取其权重的策略模块 - TensorDictBase:包含权重的 TensorDict - dict:包含权重的普通 dict - None:将尝试从服务器获取权重(使用 _get_server_weights())
worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – 需要更新的工作器的标识符。这在收集器具有多个关联工作器时很重要。
- 抛出:
TypeError – 如果提供了 worker_ids 但未配置 weight_updater。
注意
用户应扩展 WeightUpdaterBase 类来定制特定用例的权重更新逻辑。不应覆盖此方法。
另请参阅
LocalWeightsUpdaterBase
和RemoteWeightsUpdaterBase()
。