快捷方式

torchrl.trainers 包

trainer 包提供了用于编写可重用训练脚本的实用程序。核心思想是使用一个实现嵌套循环的 trainer,其中外层循环运行数据收集步骤,内层循环运行优化步骤。我们认为这适合多种 RL 训练方案,例如在线策略、离线策略、基于模型和无模型解决方案、离线 RL 等。更具体的情况,例如元 RL 算法可能具有截然不同的训练方案。

trainer.train() 方法可以概括如下:

Trainer 循环
        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

     >>> for batch in collector:
     ...     batch = self._process_batch_hook(batch)  # "batch_process"
     ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
     ...     self._pre_optim_hook()  # "pre_optim_steps"
     ...     for j in range(self.optim_steps_per_batch):
     ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
     ...         losses = self.loss_module(sub_batch)
     ...         self._post_loss_hook(sub_batch)  # "post_loss"
     ...         self.optimizer.step()
     ...         self.optimizer.zero_grad()
     ...         self._post_optim_hook()  # "post_optim"
     ...         self._post_optim_log(sub_batch)  # "post_optim_log"
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

trainer 循环中有 10 个钩子:"batch_process""pre_optim_steps""process_optim_batch""post_loss""post_steps""post_optim""pre_steps_log""post_steps_log""post_optim_log""optimizer"。它们在注释中指示了它们的应用位置。钩子可分为三类:**数据处理**("batch_process""process_optim_batch")、**日志记录**("pre_steps_log""post_optim_log""post_steps_log")以及**操作**钩子("pre_optim_steps""post_loss""post_optim""post_steps")。

  • **数据处理**钩子会更新一个数据的 tensordict。钩子的 __call__ 方法应该接受一个 TensorDict 对象作为输入,并根据某些策略对其进行更新。这类钩子的示例如回放缓冲区扩展(ReplayBufferTrainer.extend)、数据归一化(包括归一化常数更新)、数据子采样(:class:~torchrl.trainers.BatchSubSampler)等。

  • **日志记录**钩子接受一个以 TensorDict 形式呈现的数据批次,并将从中检索到的信息写入日志记录器。示例包括 LogValidationReward 钩子、奖励日志记录器(LogScalar)等。钩子应返回一个字典(或 None 值),其中包含要记录的数据。键 "log_pbar" 保留给布尔值,指示记录的值是否应显示在训练日志上打印的进度条上。

  • **操作**钩子是执行模型、数据收集器、目标网络更新等特定操作的钩子。例如,使用 UpdateWeights 同步收集器的权重或使用 ReplayBufferTrainer.update_priority 更新回放缓冲区的优先级是操作钩子的示例。它们与数据无关(不需要 TensorDict 输入),只需要在每次迭代(或每 N 次迭代)时执行一次。

TorchRL 提供的钩子通常继承自一个共同的抽象类 TrainerHookBase,并都实现了三个基本方法:用于检查点的 state_dictload_state_dict 方法,以及一个用于将钩子注册到 trainer 中默认值的 register 方法。此方法接受 trainer 和模块名称作为输入。例如,以下日志钩子每调用 10 次 "post_optim_log" 就会执行一次:

>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name):
...         trainer.register_module(self, "logging_hook")
...         trainer.register_op("post_optim_log", self)
...
...     def save_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         if self.counter % 10 == 0:
...             self.counter += 1
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out

检查点

trainer 类和钩子支持检查点,可以通过 torchsnapshot 后端或常规 torch 后端实现。这可以通过全局变量 CKPT_BACKEND 来控制。

$ CKPT_BACKEND=torchsnapshot python script.py

CKPT_BACKEND 默认为 torch。torchsnapshot 相对于 pytorch 的优势在于它是一个更灵活的 API,支持分布式检查点,并且允许用户将存储在磁盘上的文件中的张量加载到具有物理存储的张量中(这是 pytorch 目前不支持的)。例如,这允许将张量加载到否则无法放入内存的回放缓冲区中,或者从回放缓冲区加载。

在构建 trainer 时,可以提供检查点要写入的路径。对于 torchsnapshot 后端,需要一个目录路径,而 torch 后端需要一个文件路径(通常是 .pt 文件)。

>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)

Trainer.train() 方法可用于执行上述循环及其所有钩子,尽管仅使用 Trainer 类来实现其检查点功能也是完全有效的用法。

Trainer 和钩子

BatchSubSampler(batch_size[, sub_traj_len, ...])

在线 RL SOTA 实现的数据子采样器。

ClearCudaCache(interval)

按给定间隔清除 CUDA 缓存。

CountFramesLog(*args, **kwargs)

帧计数器钩子。

LogScalar([key, logname, log_pbar, ...])

用于批次中任何张量值的通用标量日志记录器钩子。

OptimizerHook(optimizer[, loss_components])

为一或多个损失组件添加优化器。

LogValidationReward(*, record_interval, ...)

用于 Trainer 的记录器钩子。

ReplayBufferTrainer(replay_buffer[, ...])

回放缓冲区钩子提供程序。

RewardNormalizer([decay, scale, eps, ...])

奖励归一化器钩子。

SelectKeys(keys)

选择 TensorDict 批次中的键。

Trainer(*args, **kwargs)

通用 Trainer 类。

TrainerHookBase()

torchrl Trainer 类的抽象钩子类。

UpdateWeights(collector, update_weights_interval)

收集器权重更新钩子类。

特定于算法的 Trainer(实验性)

警告

以下 Trainer 是实验性/原型功能。API 可能在未来版本中发生更改。请报告任何问题或反馈,以帮助改进这些实现!

TorchRL 提供高级、特定于算法的 Trainer,它们将模块化组件组合成完整的训练解决方案,具有合理的默认值和全面的配置选项。

PPOTrainer(*args, **kwargs)

PPO(Proximal Policy Optimization)Trainer 实现。

PPOTrainer

PPOTrainer 提供了一个完整的 PPO 训练解决方案,具有可配置的默认值和基于 Hydra 的全面配置系统。

主要特性

  • 完整的训练流程,包括环境设置、数据收集和优化

  • 使用数据类和 Hydra 的广泛配置系统

  • 内置的奖励、动作和训练统计数据日志记录

  • 基于现有 TorchRL 组件的模块化设计

  • **最少代码**:仅用约 20 行代码即可完成 SOTA 实现!

警告

这是一项实验性功能。API 可能在未来版本中发生更改。我们欢迎反馈和贡献,以帮助改进此实现!

快速入门 - 命令行界面

# Basic usage - train PPO on Pendulum-v1 with default settings
python sota-implementations/ppo_trainer/train.py

自定义配置

# Override specific parameters via command line
python sota-implementations/ppo_trainer/train.py \
    trainer.total_frames=2000000 \
    training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
    networks.policy_network.num_cells=[256,256] \
    optimizer.lr=0.0003

环境切换

# Switch to a different environment and logger
python sota-implementations/ppo_trainer/train.py \
    env=gym \
    training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
    logger=tensorboard

查看所有选项

# View all available configuration options
python sota-implementations/ppo_trainer/train.py --help

配置组

PPOTrainer 的配置组织成逻辑组。

  • **环境**:env_cfg__env_nameenv_cfg__backendenv_cfg__device

  • **网络**:actor_network__network__num_cellscritic_network__module__num_cells

  • **训练**:total_framesclip_normnum_epochsoptimizer_cfg__lr

  • **日志记录**:log_rewardslog_actionslog_observations

工作示例

sota-implementations/ppo_trainer/ 目录包含一个完整的、可用的 PPO 实现,它演示了 trainer 系统的简洁性和强大功能。

import hydra
from torchrl.trainers.algorithms.configs import *

@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.train()

if __name__ == "__main__":
    main()

完整的 PPO 训练,在约 20 行代码中实现完全可配置!

配置类

PPOTrainer 使用分层配置系统,包含以下主要配置类。

注意

由于使用了现代类型注解语法,该配置系统需要 Python 3.10+。

未来发展

这是计划中的第一个特定于算法的 Trainer。未来的版本将包含:

  • 其他算法:SAC、TD3、DQN、A2C 等

  • 将所有 TorchRL 组件完全集成到配置系统中

  • 增强的配置验证和错误报告

  • 高级 Trainer 的分布式训练支持

请参阅完整的配置系统文档,了解所有可用选项。

Builders

make_collector_offpolicy(make_env, ...[, ...])

返回用于离线策略 SOTA 实现的数据收集器。

make_collector_onpolicy(make_env, ...[, ...])

在在线策略设置中创建收集器。

make_dqn_loss(model, cfg)

构建 DQN 损失模块。

make_replay_buffer(device, cfg)

使用从 ReplayArgsConfig 构建的配置来构建回放缓冲区。

make_target_updater(cfg, loss_module)

构建目标网络权重更新对象。

make_trainer(collector, loss_module[, ...])

给定其组成部分,创建 Trainer 实例。

parallel_env_constructor(cfg, **kwargs)

使用适当的解析器构造函数构建的 argparse.Namespace 返回一个并行环境。

sync_async_collector(env_fns, env_kwargs[, ...])

运行异步收集器,每个收集器运行同步环境。

sync_sync_collector(env_fns, env_kwargs[, ...])

运行同步收集器,每个收集器运行同步环境。

transformed_env_constructor(cfg[, ...])

使用适当的解析器构造函数构建的 argparse.Namespace 返回一个环境创建器。

Utils

correct_for_frame_skip(cfg)

通过将所有反映帧数的参数除以 frame_skip 来更正输入 frame_skip 的参数。

get_stats_random_rollout(cfg[, ...])

使用随机 rollouts 从环境中收集统计数据(loc 和 scale)。

Loggers

Logger(exp_name, log_dir)

日志记录器的模板。

csv.CSVLogger(exp_name[, log_dir, ...])

极简依赖的 CSV 日志记录器。

mlflow.MLFlowLogger(exp_name, tracking_uri)

mlflow 日志记录器的包装器。

tensorboard.TensorboardLogger(exp_name[, ...])

Tensorboard 日志记录器的包装器。

wandb.WandbLogger(*args, **kwargs)

wandb 日志记录器的包装器。

get_logger(logger_type, logger_name, ...)

获取指定 logger_type 的日志记录器实例。

generate_exp_name(model_name, experiment_name)

使用 UUID 和当前日期生成指定实验的 ID(字符串)。

Recording utils

Recording utils 详细介绍请参阅此处

VideoRecorder(logger, tag[, in_keys, skip, ...])

视频录制器转换。

TensorDictRecorder(out_file_base[, ...])

TensorDict 录制器。

PixelRenderTransform([out_keys, preproc, ...])

一个调用父环境的 render 方法并将像素观察注册到 tensordict 中的转换。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源