torchrl.trainers 包¶
trainer 包提供了用于编写可重用训练脚本的实用程序。核心思想是使用一个实现嵌套循环的 trainer,其中外层循环运行数据收集步骤,内层循环运行优化步骤。我们认为这适合多种 RL 训练方案,例如在线策略、离线策略、基于模型和无模型解决方案、离线 RL 等。更具体的情况,例如元 RL 算法可能具有截然不同的训练方案。
trainer.train()
方法可以概括如下:
>>> for batch in collector:
... batch = self._process_batch_hook(batch) # "batch_process"
... self._pre_steps_log_hook(batch) # "pre_steps_log"
... self._pre_optim_hook() # "pre_optim_steps"
... for j in range(self.optim_steps_per_batch):
... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch"
... losses = self.loss_module(sub_batch)
... self._post_loss_hook(sub_batch) # "post_loss"
... self.optimizer.step()
... self.optimizer.zero_grad()
... self._post_optim_hook() # "post_optim"
... self._post_optim_log(sub_batch) # "post_optim_log"
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"
There are 10 hooks that can be used in a trainer loop:
>>> for batch in collector:
... batch = self._process_batch_hook(batch) # "batch_process"
... self._pre_steps_log_hook(batch) # "pre_steps_log"
... self._pre_optim_hook() # "pre_optim_steps"
... for j in range(self.optim_steps_per_batch):
... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch"
... losses = self.loss_module(sub_batch)
... self._post_loss_hook(sub_batch) # "post_loss"
... self.optimizer.step()
... self.optimizer.zero_grad()
... self._post_optim_hook() # "post_optim"
... self._post_optim_log(sub_batch) # "post_optim_log"
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"
There are 10 hooks that can be used in a trainer loop:
>>> for batch in collector:
... batch = self._process_batch_hook(batch) # "batch_process"
... self._pre_steps_log_hook(batch) # "pre_steps_log"
... self._pre_optim_hook() # "pre_optim_steps"
... for j in range(self.optim_steps_per_batch):
... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch"
... losses = self.loss_module(sub_batch)
... self._post_loss_hook(sub_batch) # "post_loss"
... self.optimizer.step()
... self.optimizer.zero_grad()
... self._post_optim_hook() # "post_optim"
... self._post_optim_log(sub_batch) # "post_optim_log"
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"
trainer 循环中有 10 个钩子:"batch_process"
、"pre_optim_steps"
、"process_optim_batch"
、"post_loss"
、"post_steps"
、"post_optim"
、"pre_steps_log"
、"post_steps_log"
、"post_optim_log"
和 "optimizer"
。它们在注释中指示了它们的应用位置。钩子可分为三类:**数据处理**("batch_process"
和 "process_optim_batch"
)、**日志记录**("pre_steps_log"
、"post_optim_log"
和 "post_steps_log"
)以及**操作**钩子("pre_optim_steps"
、"post_loss"
、"post_optim"
和 "post_steps"
)。
**数据处理**钩子会更新一个数据的 tensordict。钩子的
__call__
方法应该接受一个TensorDict
对象作为输入,并根据某些策略对其进行更新。这类钩子的示例如回放缓冲区扩展(ReplayBufferTrainer.extend
)、数据归一化(包括归一化常数更新)、数据子采样(:class:~torchrl.trainers.BatchSubSampler
)等。**日志记录**钩子接受一个以
TensorDict
形式呈现的数据批次,并将从中检索到的信息写入日志记录器。示例包括LogValidationReward
钩子、奖励日志记录器(LogScalar
)等。钩子应返回一个字典(或 None 值),其中包含要记录的数据。键"log_pbar"
保留给布尔值,指示记录的值是否应显示在训练日志上打印的进度条上。**操作**钩子是执行模型、数据收集器、目标网络更新等特定操作的钩子。例如,使用
UpdateWeights
同步收集器的权重或使用ReplayBufferTrainer.update_priority
更新回放缓冲区的优先级是操作钩子的示例。它们与数据无关(不需要TensorDict
输入),只需要在每次迭代(或每 N 次迭代)时执行一次。
TorchRL 提供的钩子通常继承自一个共同的抽象类 TrainerHookBase
,并都实现了三个基本方法:用于检查点的 state_dict
和 load_state_dict
方法,以及一个用于将钩子注册到 trainer 中默认值的 register
方法。此方法接受 trainer 和模块名称作为输入。例如,以下日志钩子每调用 10 次 "post_optim_log"
就会执行一次:
>>> class LoggingHook(TrainerHookBase):
... def __init__(self):
... self.counter = 0
...
... def register(self, trainer, name):
... trainer.register_module(self, "logging_hook")
... trainer.register_op("post_optim_log", self)
...
... def save_dict(self):
... return {"counter": self.counter}
...
... def load_state_dict(self, state_dict):
... self.counter = state_dict["counter"]
...
... def __call__(self, batch):
... if self.counter % 10 == 0:
... self.counter += 1
... out = {"some_value": batch["some_value"].item(), "log_pbar": False}
... else:
... out = None
... self.counter += 1
... return out
检查点¶
trainer 类和钩子支持检查点,可以通过 torchsnapshot
后端或常规 torch 后端实现。这可以通过全局变量 CKPT_BACKEND
来控制。
$ CKPT_BACKEND=torchsnapshot python script.py
CKPT_BACKEND
默认为 torch
。torchsnapshot 相对于 pytorch 的优势在于它是一个更灵活的 API,支持分布式检查点,并且允许用户将存储在磁盘上的文件中的张量加载到具有物理存储的张量中(这是 pytorch 目前不支持的)。例如,这允许将张量加载到否则无法放入内存的回放缓冲区中,或者从回放缓冲区加载。
在构建 trainer 时,可以提供检查点要写入的路径。对于 torchsnapshot
后端,需要一个目录路径,而 torch
后端需要一个文件路径(通常是 .pt
文件)。
>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
... collector=collector,
... total_frames=total_frames,
... frame_skip=frame_skip,
... loss_module=loss_module,
... optimizer=optimizer,
... save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)
Trainer.train()
方法可用于执行上述循环及其所有钩子,尽管仅使用 Trainer
类来实现其检查点功能也是完全有效的用法。
Trainer 和钩子¶
|
在线 RL SOTA 实现的数据子采样器。 |
|
按给定间隔清除 CUDA 缓存。 |
|
帧计数器钩子。 |
|
用于批次中任何张量值的通用标量日志记录器钩子。 |
|
为一或多个损失组件添加优化器。 |
|
用于 |
|
回放缓冲区钩子提供程序。 |
|
奖励归一化器钩子。 |
|
选择 TensorDict 批次中的键。 |
|
通用 Trainer 类。 |
torchrl Trainer 类的抽象钩子类。 |
|
|
收集器权重更新钩子类。 |
特定于算法的 Trainer(实验性)¶
警告
以下 Trainer 是实验性/原型功能。API 可能在未来版本中发生更改。请报告任何问题或反馈,以帮助改进这些实现!
TorchRL 提供高级、特定于算法的 Trainer,它们将模块化组件组合成完整的训练解决方案,具有合理的默认值和全面的配置选项。
|
PPO(Proximal Policy Optimization)Trainer 实现。 |
PPOTrainer¶
PPOTrainer
提供了一个完整的 PPO 训练解决方案,具有可配置的默认值和基于 Hydra 的全面配置系统。
主要特性
完整的训练流程,包括环境设置、数据收集和优化
使用数据类和 Hydra 的广泛配置系统
内置的奖励、动作和训练统计数据日志记录
基于现有 TorchRL 组件的模块化设计
**最少代码**:仅用约 20 行代码即可完成 SOTA 实现!
警告
这是一项实验性功能。API 可能在未来版本中发生更改。我们欢迎反馈和贡献,以帮助改进此实现!
快速入门 - 命令行界面
# Basic usage - train PPO on Pendulum-v1 with default settings
python sota-implementations/ppo_trainer/train.py
自定义配置
# Override specific parameters via command line
python sota-implementations/ppo_trainer/train.py \
trainer.total_frames=2000000 \
training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
networks.policy_network.num_cells=[256,256] \
optimizer.lr=0.0003
环境切换
# Switch to a different environment and logger
python sota-implementations/ppo_trainer/train.py \
env=gym \
training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
logger=tensorboard
查看所有选项
# View all available configuration options
python sota-implementations/ppo_trainer/train.py --help
配置组
PPOTrainer 的配置组织成逻辑组。
**环境**:
env_cfg__env_name
、env_cfg__backend
、env_cfg__device
**网络**:
actor_network__network__num_cells
、critic_network__module__num_cells
**训练**:
total_frames
、clip_norm
、num_epochs
、optimizer_cfg__lr
**日志记录**:
log_rewards
、log_actions
、log_observations
工作示例
sota-implementations/ppo_trainer/
目录包含一个完整的、可用的 PPO 实现,它演示了 trainer 系统的简洁性和强大功能。
import hydra
from torchrl.trainers.algorithms.configs import *
@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
trainer = hydra.utils.instantiate(cfg.trainer)
trainer.train()
if __name__ == "__main__":
main()
完整的 PPO 训练,在约 20 行代码中实现完全可配置!
配置类
PPOTrainer 使用分层配置系统,包含以下主要配置类。
注意
由于使用了现代类型注解语法,该配置系统需要 Python 3.10+。
**Trainer**:
PPOTrainerConfig
**环境**:
GymEnvConfig
、BatchedEnvConfig
**网络**:
MLPConfig
、TanhNormalModelConfig
**数据**:
TensorDictReplayBufferConfig
、MultiaSyncDataCollectorConfig
**目标**:
PPOLossConfig
**优化器**:
AdamConfig
、AdamWConfig
**日志记录**:
WandbLoggerConfig
、TensorboardLoggerConfig
未来发展
这是计划中的第一个特定于算法的 Trainer。未来的版本将包含:
其他算法:SAC、TD3、DQN、A2C 等
将所有 TorchRL 组件完全集成到配置系统中
增强的配置验证和错误报告
高级 Trainer 的分布式训练支持
请参阅完整的配置系统文档,了解所有可用选项。
Builders¶
|
返回用于离线策略 SOTA 实现的数据收集器。 |
|
在在线策略设置中创建收集器。 |
|
构建 DQN 损失模块。 |
|
使用从 ReplayArgsConfig 构建的配置来构建回放缓冲区。 |
|
构建目标网络权重更新对象。 |
|
给定其组成部分,创建 Trainer 实例。 |
|
使用适当的解析器构造函数构建的 argparse.Namespace 返回一个并行环境。 |
|
运行异步收集器,每个收集器运行同步环境。 |
|
运行同步收集器,每个收集器运行同步环境。 |
|
使用适当的解析器构造函数构建的 argparse.Namespace 返回一个环境创建器。 |
Utils¶
通过将所有反映帧数的参数除以 frame_skip 来更正输入 frame_skip 的参数。 |
|
|
使用随机 rollouts 从环境中收集统计数据(loc 和 scale)。 |
Loggers¶
|
日志记录器的模板。 |
|
极简依赖的 CSV 日志记录器。 |
|
mlflow 日志记录器的包装器。 |
|
Tensorboard 日志记录器的包装器。 |
|
wandb 日志记录器的包装器。 |
|
获取指定 |
|
使用 UUID 和当前日期生成指定实验的 ID(字符串)。 |
Recording utils¶
Recording utils 详细介绍请参阅此处。
|
视频录制器转换。 |
|
TensorDict 录制器。 |
|
一个调用父环境的 render 方法并将像素观察注册到 tensordict 中的转换。 |