Trainer¶
- class torchrl.trainers.Trainer(*args, **kwargs)[source]¶
一个通用的 Trainer 类。
Trainer 负责收集数据和训练模型。为了使该类尽可能通用,Trainer 不会构建任何其特定操作:它们都必须在训练循环的特定点进行挂钩。
要构建一个 Trainer,需要一个可迭代的数据源(一个
collector
)、一个损失模块和一个优化器。- 参数:
collector (Sequence[TensorDictBase]) – 一个可迭代对象,以 TensorDict 形式返回数据批次,形状为 [batch x time steps]。
total_frames (int) – 训练期间要收集的总帧数。
loss_module (LossModule) – 一个模块,用于读取 TensorDict 批次(可能从回放缓冲区采样)并返回一个损失 TensorDict,其中每个键都指向一个不同的损失组件。
optimizer (optim.Optimizer) – 一个用于训练模型参数的优化器。
logger (Logger, optional) – 一个将处理日志记录的 Logger。
optim_steps_per_batch (int) – 每个数据批次的优化步数。Trainer 的工作方式如下:主循环收集数据批次(epoch 循环),子循环(训练循环)在两次数据收集之间执行模型更新。
clip_grad_norm (bool, optional) – 如果为 True,则将根据模型参数的总范数来裁剪梯度。如果为 False,则所有偏导数都将被限制在 (-clip_norm, clip_norm) 范围内。默认为
True
。clip_norm (Number, optional) – 用于裁剪梯度的值。默认为 None(不裁剪范数)。
progress_bar (bool, optional) – 如果为 True,则使用 tqdm 显示进度条。如果未安装 tqdm,此选项将不起任何作用。默认为
True
seed (int, optional) – 用于 collector、pytorch 和 numpy 的种子。默认为
None
。save_trainer_interval (int, optional) – Trainer 保存到磁盘的频率,以帧数计。默认为 10000。
log_interval (int, optional) – 值记录的频率,以帧数计。默认为 10000。
save_trainer_file (path, optional) – 保存 trainer 的路径。默认为 None(不保存)