快捷方式

RayLLMCollector

class torchrl.collectors.llm.RayLLMCollector(env: EnvBase | Callable[[], EnvBase], *, policy: Callable[[TensorDictBase], TensorDictBase] | None = None, policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]] | None = None, dialog_turns_per_batch: int, total_dialog_turns: int = - 1, yield_only_last_steps: bool | None = None, yield_completed_trajectories: bool | None = None, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, async_envs: bool | None = None, replay_buffer: ReplayBuffer | None = None, reset_at_each_iter: bool = False, flatten_data: bool | None = None, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, ray_init_config: dict[str, Any] | None = None, remote_config: dict[str, Any] | None = None, track_policy_version: bool | PolicyVersion = False, sync_iter: bool = True, verbose: bool = False, num_cpus: int | None = None, num_gpus: int | None = None)[source]

一个轻量级的 Ray 实现的 LLM Collector,可以远程扩展和采样。

参数:

env (EnvBaseEnvBase 构造函数) – 用于数据收集的环境。

关键字参数:
  • policy (Callable[[TensorDictBase], TensorDictBase]) – 用于数据收集的策略。

  • policy_factory (Callable[[], Callable], optional) – 一个返回策略实例的可调用对象。这与 policy 参数互斥。

  • dialog_turns_per_batch (int) – 一个关键字参数,表示批次中的总元素数量。

  • total_dialog_turns (int) – 一个关键字参数,表示收集器在其生命周期内返回的总对话轮次数。

  • yield_only_last_steps (bool, optional) – 是否生成轨迹的每一步,还是只生成最后(完成)的步骤。

  • yield_completed_trajectories (bool, optional) – 是生成具有给定步数的 rollout 批次,还是生成单个、完整的轨迹。

  • postproc (Callable, optional) – 一个后处理转换。

  • async_envs (bool, optional) – 如果为 True,环境将异步运行。

  • replay_buffer (ReplayBuffer, optional) – 如果提供,收集器将不会生成 tensordicts,而是填充缓冲区。

  • reset_at_each_iter (bool, optional) – 如果为 True,环境将在每次迭代时重置。

  • flatten_data (bool, optional) – 如果为 True,收集器将在返回前展平收集到的数据。

  • weight_updater (WeightUpdaterBase构造函数, optional) – WeightUpdaterBase 实例或其子类,负责在远程推理工作器上更新策略权重。

  • ray_init_config (dict[str, Any], optional) – 传递给 ray.init() 的关键字参数。

  • remote_config (dict[str, Any], optional) – 传递给 cls.as_remote() 的关键字参数。

  • num_cpus (int, optional) – Actor 的 CPU 数量。默认为 None (从 remote_config 获取)。

  • num_gpus (int, optional) – Actor 的 GPU 数量。默认为 None (从 remote_config 获取)。

  • sync_iter (bool, optional) –

    如果为 True,收集器生成的项目将被同步到本地进程。如果为 False,收集器将在生成之间收集下一批数据。这在通过 start() 方法收集数据时无效。例如

    >>> collector = RayLLMCollector(..., sync_iter=True)
    >>> for data in collector:  # blocking
    ...     # expensive operation - collector is idle
    >>> collector = RayLLMCollector(..., sync_iter=False)
    >>> for data in collector:  # non-blocking
    ...     # expensive operation - collector is collecting data
    

    这在某种程度上等同于使用 MultiSyncDataCollector (sync_iter=True) 或 MultiAsyncDataCollector (sync_iter=False) 。默认为 True

  • verbose (bool, optional) – 如果为 True,收集器将打印进度信息。默认为 False

classmethod as_remote(remote_config: dict[str, Any] | None = None)

创建一个远程 ray 类的实例。

参数:
  • cls (Python Class) – 要远程实例化的类。

  • remote_config (dict) – 为此类保留的 CPU 核心数量。

返回:

一个创建 ray 远程类实例的函数。

async_shutdown(timeout=None)[source]

异步关闭收集器。

property dialog_turns_per_batch: int

每个批次的对话轮次数。

get_policy_model()

获取策略模型。

RayLLMCollector 使用此方法来获取用于权重更新的远程 LLM 实例。

返回:

策略模型实例

get_policy_version() str | int | None

获取当前策略版本。

此方法用于支持 Ray actor 中的远程调用,因为属性无法通过 Ray 的 RPC 机制直接访问。

返回:

当前版本号(整数)或 UUID(字符串),如果版本跟踪已禁用则为 None。

increment_version()[source]

增加策略版本。

init_updater(*args, **kwargs)[source]

使用自定义参数初始化权重更新器。

此方法调用远程收集器上的 init_updater。

参数:
  • *args – 用于权重更新器初始化的位置参数

  • **kwargs – 用于权重更新器初始化的关键字参数

is_initialized() bool

检查收集器是否已初始化并准备好。

返回:

如果收集器已初始化并准备好收集数据,则为 True。

返回类型:

布尔值

iterator() Iterator[TensorDictBase]

迭代 DataCollector。

Yields: 包含轨迹 (块) 的 TensorDictBase 对象

load_state_dict(state_dict: OrderedDict, **kwargs) None

在环境和策略上加载 state_dict。

参数:

state_dict (OrderedDict) – 包含 “policy_state_dict”"env_state_dict" 字段的有序字典。

next() None[source]

从收集器获取下一批数据。

返回:

None,因为数据直接写入回放缓冲区。

pause()

上下文管理器,如果收集器正在自由运行,则暂停收集器。

property policy_version: str | int | None

策略的当前版本。

返回:

当前版本号(整数)或 UUID(字符串),如果版本跟踪已禁用则为 None。

reset(index=None, **kwargs) None

将环境重置到新的初始状态。

property rollout: Callable[[], TensorDictBase]

返回 rollout 函数。

set_seed(seed: int, static_seed: bool = False) int

设置 DataCollector 中存储的环境的种子。

参数:
  • seed (int) – 用于环境的种子整数。

  • static_seed (bool, optional) – 如果 True,种子不会递增。默认为 False

返回:

输出种子。当 DataCollector 包含多个环境时,这很有用,因为种子会为每个环境递增。结果种子是最后一个环境的种子。

示例

>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = ParallelEnv(6, env_fn)
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
>>> out_seed = collector.set_seed(1)  # out_seed = 6
shutdown()[source]

关闭收集器。

start()[source]

在后台线程中启动收集器。

state_dict() OrderedDict

返回数据收集器的本地 state_dict(环境和策略)。

返回:

包含 "policy_state_dict"“env_state_dict” 字段的有序字典。

property total_dialog_turns

要收集的总对话轮次数。

update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, **kwargs)[source]

在远程工作器上更新策略权重。

参数:
  • policy_or_weights – 要更新的权重。可以是: - TensorDictModuleBase:一个将提取其权重的策略模块 - TensorDictBase:一个包含权重的 TensorDict - dict:一个包含权重的常规 dict - None:将尝试使用 _get_server_weights() 从服务器获取权重。

  • worker_ids – 要更新的工作器。如果为 None,则更新所有工作器。

property weight_updater: WeightUpdaterBase

权重更新器实例。

我们可以传递权重更新器,因为它是无状态的,因此是可序列化的。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源