快捷方式

torchrl.trainers 包

trainer 包提供了编写可重用训练脚本的实用程序。核心思想是使用一个实现嵌套循环的 trainer,外层循环执行数据收集步骤,内层循环执行优化步骤。我们相信这适用于多种强化学习训练方案,如同策略、离策略、基于模型和无模型解决方案、离线强化学习等。更具体的情况,如元强化学习算法,其训练方案可能存在显著差异。

trainer.train() 方法可以概括如下:

Trainer 循环
        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

     >>> for batch in collector:
     ...     batch = self._process_batch_hook(batch)  # "batch_process"
     ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
     ...     self._pre_optim_hook()  # "pre_optim_steps"
     ...     for j in range(self.optim_steps_per_batch):
     ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
     ...         losses = self.loss_module(sub_batch)
     ...         self._post_loss_hook(sub_batch)  # "post_loss"
     ...         self.optimizer.step()
     ...         self.optimizer.zero_grad()
     ...         self._post_optim_hook()  # "post_optim"
     ...         self._post_optim_log(sub_batch)  # "post_optim_log"
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

Trainer 循环中有 10 个钩子可以使用:"batch_process""pre_optim_steps""process_optim_batch""post_loss""post_steps""post_optim""pre_steps_log""post_steps_log""post_optim_log""optimizer"。它们在注释中标明了应用位置。钩子可分为 3 类:**数据处理** ("batch_process""process_optim_batch")、**日志记录** ("pre_steps_log""post_optim_log""post_steps_log") 和 **操作** 钩子 ("pre_optim_steps""post_loss""post_optim""post_steps")。

  • 数据处理 钩子会更新一个 tensordict 数据。钩子的 __call__ 方法应接受一个 TensorDict 对象作为输入,并根据某种策略进行更新。此类钩子的示例包括回放缓冲区扩展 (ReplayBufferTrainer.extend)、数据归一化(包括归一化常数更新)、数据子采样 (:class:~torchrl.trainers.BatchSubSampler) 等。

  • 日志记录 钩子接收以 TensorDict 形式表示的数据批次,并从该数据中检索信息写入日志记录器。示例包括 LogValidationReward 钩子、奖励记录器 (LogScalar) 等。钩子应返回一个字典(或 None 值),其中包含要记录的数据。键 "log_pbar" 保留给布尔值,指示记录的值是否应显示在训练日志上打印的进度条上。

  • 操作 钩子是执行特定操作的钩子,例如在模型、数据收集器、目标网络更新等之上执行。例如,使用 UpdateWeights 同步收集器权重或使用 ReplayBufferTrainer.update_priority 更新回放缓冲区的优先级是操作钩子的示例。它们与数据无关(不需要 TensorDict 输入),只需在每次迭代(或每 N 次迭代)执行一次。

TorchRL 提供的钩子通常继承自一个共同的抽象类 TrainerHookBase,并且都实现了三个基本方法:用于检查点设置的 state_dictload_state_dict 方法,以及一个在 trainer 中将钩子注册为默认值的 register 方法。此方法接受 trainer 和模块名称作为输入。例如,以下日志记录钩子在每次调用 "post_optim_log" 的 10 次后执行:

>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name):
...         trainer.register_module(self, "logging_hook")
...         trainer.register_op("post_optim_log", self)
...
...     def save_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         if self.counter % 10 == 0:
...             self.counter += 1
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out

检查点

trainer 类和钩子支持检查点,这可以通过 torchsnapshot 后端或常规的 torch 后端来实现。这可以通过全局变量 CKPT_BACKEND 控制。

$ CKPT_BACKEND=torchsnapshot python script.py

CKPT_BACKEND 默认为 torch。torchsnapshot 相对于 pytorch 的优势在于它是一个更灵活的 API,支持分布式检查点,并且允许用户将存储在磁盘上的文件中的张量加载到具有物理存储的张量中(这是 pytorch 目前不支持的)。这允许例如将张量加载到或从一个否则无法放入内存的回放缓冲区中。

在构建 trainer 时,可以提供检查点要写入的路径。对于 torchsnapshot 后端,需要一个目录路径,而 torch 后端需要一个文件路径(通常是 .pt 文件)。

>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)

Trainer.train() 方法可用于执行具有所有钩子的上述循环,尽管仅使用 Trainer 类来执行其检查点功能也是完全有效的用法。

Trainer 和钩子

BatchSubSampler(batch_size[, sub_traj_len, ...])

在线强化学习 sota 实现的数据子采样器。

ClearCudaCache(interval)

以给定的间隔清除 cuda 缓存。

CountFramesLog(*args, **kwargs)

一个帧计数器钩子。

LogScalar([logname, log_pbar, reward_key])

奖励记录器钩子。

OptimizerHook(optimizer[, loss_components])

为一或多个损失分量添加一个优化器。

LogValidationReward(*, record_interval, ...)

Trainer 的记录器钩子。

ReplayBufferTrainer(replay_buffer[, ...])

回放缓冲区钩子提供程序。

RewardNormalizer([decay, scale, eps, ...])

奖励归一化器钩子。

SelectKeys(keys)

在 TensorDict 批次中选择键。

Trainer(*args, **kwargs)

一个通用的 Trainer 类。

TrainerHookBase()

torchrl Trainer 类的一个抽象钩子类。

UpdateWeights(collector, update_weights_interval)

一个收集器权重更新钩子类。

构建器

make_collector_offpolicy(make_env, ...[, ...])

为离策略 sota 实现返回一个数据收集器。

make_collector_onpolicy(make_env, ...[, ...])

在同策略设置中创建一个收集器。

make_dqn_loss(model, cfg)

构建 DQN 损失模块。

make_replay_buffer(device, cfg)

使用从 ReplayArgsConfig 构建的配置构建回放缓冲区。

make_target_updater(cfg, loss_module)

构建目标网络权重更新对象。

make_trainer(collector, loss_module[, ...])

给定组成部分,创建 Trainer 实例。

parallel_env_constructor(cfg, **kwargs)

从使用适当解析器构造函数构建的 argparse.Namespace 返回一个并行环境。

sync_async_collector(env_fns, env_kwargs[, ...])

运行异步收集器,每个收集器运行同步环境。

sync_sync_collector(env_fns, env_kwargs[, ...])

运行同步收集器,每个收集器运行同步环境。

transformed_env_constructor(cfg[, ...])

从使用适当解析器构造函数构建的 argparse.Namespace 返回一个环境创建器。

工具

correct_for_frame_skip(cfg)

通过将所有反映帧数的参数除以 frame_skip 来修正输入 frame_skip 的参数。

get_stats_random_rollout(cfg[, ...])

使用随机 rollout 从环境中收集统计数据(loc 和 scale)。

日志记录器

Logger(exp_name, log_dir)

日志记录器的模板。

csv.CSVLogger(exp_name[, log_dir, ...])

最少依赖的 CSV 日志记录器。

mlflow.MLFlowLogger(exp_name, tracking_uri)

mlflow 日志记录器的包装器。

tensorboard.TensorboardLogger(exp_name[, ...])

Tensorboard 日志记录器的包装器。

wandb.WandbLogger(*args, **kwargs)

wandb 日志记录器的包装器。

get_logger(logger_type, logger_name, ...)

获取提供商 logger_type 的日志记录器实例。

generate_exp_name(model_name, experiment_name)

使用 UUID 和当前日期为描述的实验生成 ID(字符串)。

录制工具

录制工具的详细信息请参见 这里

VideoRecorder(logger, tag[, in_keys, skip, ...])

视频录制器转换。

TensorDictRecorder(out_file_base[, ...])

TensorDict 录制器。

PixelRenderTransform([out_keys, preproc, ...])

一个调用父环境的 render 方法并将像素观察注册到 tensordict 中的转换。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源