快捷方式

使用模型优化入门

作者Vincent Moens

注意

要在 notebook 中运行本教程,请在开头添加一个安装单元格,其中包含:

!pip install tensordict
!pip install torchrl

在 TorchRL 中,我们尝试像 PyTorch 中通常做的那样来处理优化,使用专门设计的仅用于优化模型的损失模块。这种方法有效地将策略的执行与其训练分离开来,使我们能够设计与传统监督学习示例中的训练循环类似的训练循环。

因此,典型的训练循环如下所示:

..code - block::Python

>>> for i in range(n_collections):
...     data = get_next_batch(env, policy)
...     for j in range(n_optim):
...         loss = loss_fn(data)
...         loss.backward()
...         optim.step()

在这个简洁的教程中,您将获得对损失模块的简要概述。由于基本用法 API 通常很简单,因此本教程将保持简短。

RL 目标函数

在 RL 中,创新通常涉及探索新的策略优化方法(即新算法),而不是像其他领域那样专注于新架构。在 TorchRL 中,这些算法被封装在损失模块中。损失模块协调算法的各个组件,并产生一组可以反向传播以训练相应组件的损失值。

在本教程中,我们将以一个流行的离策略算法 DDPG 为例,DDPG

要构建损失模块,唯一需要的是一组定义为 :class:`~tensordict.nn.TensorDictModule` 的网络。大多数时候,其中一个模块是策略。也可能需要其他辅助网络,例如 Q 值网络或某种形式的 critics。让我们看看在实践中它是如何工作的:DDPG 需要一个从观察空间到动作空间的确定性映射,以及一个预测状态-动作对值的价值网络。DDPG 损失将尝试找到输出能够最大化给定状态的动作的策略参数。

要构建损失,我们需要 actor 和 value 网络。如果它们是根据 DDPG 的预期构建的,那么它们就是我们获得可训练损失模块所需的一切。

from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")

from torchrl.modules import Actor, MLP, ValueOperator
from torchrl.objectives import DDPGLoss

n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]
actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]))
value_net = ValueOperator(
    MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]),
    in_keys=["observation", "action"],
)

ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net)

就是这样!我们的损失模块现在可以与来自环境的数据一起运行(为了专注于损失功能,我们省略了探索、存储和其他功能)。

rollout = env.rollout(max_steps=100, policy=actor)
loss_vals = ddpg_loss(rollout)
print(loss_vals)

LossModule 的输出

如您所见,我们从损失中收到的值不是一个标量,而是一个包含多个损失的字典。

原因很简单:因为可能同时训练多个网络,并且由于某些用户可能希望将每个模块的优化分开进行,因此 TorchRL 的目标将返回包含各种损失组件的字典。

此格式还允许我们将元数据与损失值一起传递。总的来说,我们确保只有损失值是可微分的,这样您就可以简单地将字典中的值相加来获得总损失。如果您想确保您完全控制发生了什么,您可以只对键以 "loss_" 前缀开头的条目求和。

total_loss = 0
for key, val in loss_vals.items():
    if key.startswith("loss_"):
        total_loss += val

训练 LossModule

鉴于所有这些,训练模块与任何其他训练循环中的操作并没有太大区别。因为损失模块包装了这些模块,所以获取可训练参数列表的最简单方法是查询 parameters() 方法。

我们将需要一个优化器(或者如果您选择每个模块一个优化器)。

from torch.optim import Adam

optim = Adam(ddpg_loss.parameters())
total_loss.backward()

以下项目通常会在您的训练循环中找到:

optim.step()
optim.zero_grad()

进一步考虑:目标参数

另一个需要考虑的重要方面是像 DDPG 这样的离策略算法中目标参数的存在。目标参数通常代表参数随时间的延迟或平滑版本,它们在策略训练期间在值估计中起着至关重要的作用。与使用当前值网络参数配置相比,使用目标参数进行策略训练通常效率更高。通常,目标参数的管理由损失模块处理,让用户无需直接担心。但是,根据具体要求,用户仍有责任根据需要更新这些值。TorchRL 提供了几个更新器,即 HardUpdateSoftUpdate,无需深入了解损失模块的底层机制即可轻松实例化。

from torchrl.objectives import SoftUpdate

updater = SoftUpdate(ddpg_loss, eps=0.99)

在您的训练循环中,您需要在每个优化步骤或每个收集步骤中更新目标参数。

updater.step()

这就是开始使用损失模块所需了解的全部内容!

要进一步探讨该主题,请查看:

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源