RayLLMCollector¶
- class torchrl.collectors.llm.RayLLMCollector(env: EnvBase | Callable[[], EnvBase], *, policy: Callable[[TensorDictBase], TensorDictBase] | None = None, policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]] | None = None, dialog_turns_per_batch: int, total_dialog_turns: int = - 1, yield_only_last_steps: bool | None = None, yield_completed_trajectories: bool | None = None, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, async_envs: bool | None = None, replay_buffer: ReplayBuffer | None = None, reset_at_each_iter: bool = False, flatten_data: bool | None = None, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, ray_init_config: dict[str, Any] | None = None, remote_config: dict[str, Any] | None = None, track_policy_version: bool | PolicyVersion = False, sync_iter: bool = True, verbose: bool = False)[source]¶
一个轻量级的 Ray 实现的 LLM Collector,可以远程扩展和采样。
- 参数:
env (EnvBase 或 EnvBase 构造函数) – 用于数据收集的环境。
- 关键字参数:
policy (Callable[[TensorDictBase], TensorDictBase]) – 用于数据收集的策略。
policy_factory (Callable[[], Callable], optional) – 返回策略实例的可调用对象。这与 policy 参数互斥。
dialog_turns_per_batch (int) – 一个关键字参数,表示批次中的总元素数量。
total_dialog_turns (int) – 一个关键字参数,表示 Collector 在其生命周期内返回的总对话轮次。
yield_only_last_steps (bool, optional) – 是生成轨迹的每一步,还是只生成最后(完成)的步骤。
yield_completed_trajectories (bool, optional) – 是生成具有给定步数的 rollout 批次,还是生成单个、已完成的轨迹。
postproc (Callable, optional) – 一个后处理转换。
async_envs (bool, optional) – 如果为 True,环境将异步运行。
replay_buffer (ReplayBuffer, optional) – 如果提供,Collector 将不会 yield tensordicts,而是填充缓冲区。
reset_at_each_iter (bool, optional) – 如果为 True,环境将在每次迭代时重置。
flatten_data (bool, optional) – 如果为 True,Collector 将在返回前展平收集到的数据。
weight_updater (WeightUpdaterBase 或 构造函数, optional) – WeightUpdaterBase 或其子类的实例,负责在远程推理工作节点上更新策略权重。
ray_init_config (dict[str, Any], optional) – 传递给 ray.init() 的关键字参数。
remote_config (dict[str, Any], optional) – 传递给 cls.as_remote() 的关键字参数。
sync_iter (bool, optional) –
如果为 True,Collector yield 的项将被同步到本地进程。如果为 False,Collector 将在 yield 之间收集下一批数据。这在通过
start()
方法收集数据时无效。例如:>>> collector = RayLLMCollector(..., sync_iter=True) >>> for data in collector: # blocking ... # expensive operation - collector is idle >>> collector = RayLLMCollector(..., sync_iter=False) >>> for data in collector: # non-blocking ... # expensive operation - collector is collecting data
这在某种程度上等同于使用
MultiSyncDataCollector
(sync_iter=True) 或MultiAsyncDataCollector
(sync_iter=False)。默认为 True。verbose (bool, optional) – 如果为
True
,Collector 将打印进度信息。默认为 False。
- classmethod as_remote(remote_config: dict[str, Any] | None = None)¶
创建一个远程 ray 类的实例。
- 参数:
cls (Python Class) – 要远程实例化的类。
remote_config (dict) – 为此类保留的 CPU 核心数量。
- 返回:
一个创建 ray 远程类实例的函数。
- property dialog_turns_per_batch: int¶
每批对话轮次数。
- get_policy_model()¶
获取策略模型。
此方法由 RayLLMCollector 用于获取远程 LLM 实例以进行权重更新。
- 返回:
策略模型实例
- get_policy_version() str | int | None ¶
获取当前策略版本。
此方法是为了支持 Ray actor 中的远程调用而存在的,因为属性不能通过 Ray 的 RPC 机制直接访问。
- 返回:
当前版本号(int)或 UUID(str),如果禁用版本跟踪则为 None。
- init_updater(*args, **kwargs)[source]¶
使用自定义参数初始化权重更新器。
此方法调用远程 Collector 上的 init_updater。
- 参数:
*args – 用于权重更新器初始化的位置参数
**kwargs – 用于权重更新器初始化的关键字参数
- is_initialized() bool ¶
检查 Collector 是否已初始化并准备就绪。
- 返回:
如果 Collector 已初始化并准备好收集数据,则返回 True。
- 返回类型:
布尔值
- iterator() Iterator[TensorDictBase] ¶
迭代 DataCollector。
Yields: 包含轨迹 (块) 的 TensorDictBase 对象
- load_state_dict(state_dict: OrderedDict, **kwargs) None ¶
在环境和策略上加载 state_dict。
- 参数:
state_dict (OrderedDict) – 包含 “policy_state_dict” 和
"env_state_dict"
字段的有序字典。
- pause()¶
上下文管理器,如果收集器正在自由运行,则暂停收集器。
- property policy_version: str | int | None¶
策略的当前版本。
- 返回:
当前版本号(int)或 UUID(str),如果禁用版本跟踪则为 None。
- reset(index=None, **kwargs) None ¶
将环境重置到新的初始状态。
- property rollout: Callable[[], TensorDictBase]¶
返回 rollout 函数。
- set_seed(seed: int, static_seed: bool = False) int ¶
设置 DataCollector 中存储的环境的种子。
- 参数:
seed (int) – 用于环境的种子整数。
static_seed (bool, optional) – 如果
True
,种子不会递增。默认为 False
- 返回:
输出种子。当 DataCollector 包含多个环境时,这很有用,因为种子会为每个环境递增。结果种子是最后一个环境的种子。
示例
>>> from torchrl.envs import ParallelEnv >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule >>> from torch import nn >>> env_fn = lambda: GymEnv("Pendulum-v1") >>> env_fn_parallel = ParallelEnv(6, env_fn) >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) >>> out_seed = collector.set_seed(1) # out_seed = 6
- state_dict() OrderedDict ¶
返回数据收集器的本地 state_dict(环境和策略)。
- 返回:
包含
"policy_state_dict"
和 “env_state_dict” 字段的有序字典。
- property total_dialog_turns¶
要收集的总对话轮次数。
- update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, **kwargs)[source]¶
在远程工作节点上更新策略权重。
- 参数:
policy_or_weights – 要更新的权重。可以是: - TensorDictModuleBase:一个将提取其权重的策略模块 - TensorDictBase:一个包含权重的 TensorDict - dict:一个常规的包含权重的 dict - None:将尝试通过 _get_server_weights() 从服务器获取权重。
worker_ids – 要更新的工作节点。如果为 None,则更新所有工作节点。
- property weight_updater: WeightUpdaterBase¶
权重更新器实例。
我们可以传递权重更新器,因为它无状态,因此可以序列化。