评价此页

循环 DQN:训练循环策略#

创建日期:2023年11月8日 | 最后更新:2025年1月27日 | 最后验证:未验证

作者Vincent Moens

您将学到什么
  • 如何在 TorchRL 的 Actor 中合并 RNN

  • 如何将该基于记忆的策略与重放缓冲区和损失模块结合使用

先决条件
  • PyTorch v2.0.0

  • gym[mujoco]

  • tqdm

概述#

基于记忆的策略不仅在观察结果是部分可观测时至关重要,而且在需要考虑时间维度以做出明智决策时也至关重要。

循环神经网络长期以来一直是基于记忆策略的热门工具。其思路是在两个连续步骤之间保持循环状态,并将其与当前观察结果一起作为策略的输入。

本教程展示了如何使用 TorchRL 将 RNN 合并到策略中。

主要学习内容

  • 在 TorchRL 的 Actor 中合并 RNN;

  • 将该基于记忆的策略与重放缓冲区和损失模块结合使用。

在 TorchRL 中使用 RNN 的核心思想是将 TensorDict 作为从一个步骤到另一个步骤的隐藏状态的数据载体。我们将构建一个策略,从当前的 TensorDict 读取之前的循环状态,并将当前的循环状态写入下一个状态的 TensorDict 中。

Data collection with a recurrent policy

如图所示,我们的环境使用零填充的循环状态填充 TensorDict,策略将其与观察结果一起读取以生成动作,并产生将用于下一步的循环状态。当调用 step_mdp() 函数时,来自下一个状态的循环状态会被带入当前的 TensorDict。让我们看看这在实践中是如何实现的。

如果您在 Google Colab 中运行此代码,请确保安装以下依赖项

!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm

设置#

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import (
    Compose,
    ExplorationType,
    GrayScale,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

环境#

像往常一样,第一步是构建我们的环境:它有助于我们定义问题并据此构建策略网络。在本教程中,我们将运行一个基于像素的 CartPole gym 环境实例,并进行一些自定义转换:转换为灰度图、调整大小为 84x84、缩小奖励并归一化观察结果。

注意

StepCounter 转换是辅助性的。由于 CartPole 任务的目标是尽可能延长轨迹,因此计算步数可以帮助我们跟踪策略的性能。

对于本教程,有两个转换很重要

  • InitTracker 将通过在 TensorDict 中添加一个 "is_init" 布尔掩码来标记对 reset() 的调用,该掩码将跟踪哪些步骤需要重置 RNN 隐藏状态。

  • TensorDictPrimer 转换在技术上稍微复杂一些。使用 RNN 策略并不一定需要它。但是,它会指示环境(以及随后的收集器)预期存在一些额外的键。添加后,调用 env.reset() 将使用零填充张量填充 Primer 中指定的条目。由于知道策略需要这些张量,收集器将在收集期间传递它们。最终,我们将把隐藏状态存储在重放缓冲区中,这将有助于我们在损失模块中引导 RNN 计算(否则将以 0 初始化)。总而言之:不包含此转换不会对我们的策略训练产生巨大影响,但它会导致循环键从收集的数据和重放缓冲区中消失,这反过来会导致训练效果略逊一筹。幸运的是,我们建议的 LSTMModule 配备了一个辅助方法来为我们构建该转换,所以我们可以等到构建它时再处理!

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, device=device),
    Compose(
        ToTensorImage(),
        GrayScale(),
        Resize(84, 84),
        StepCounter(),
        InitTracker(),
        RewardScaling(loc=0.0, scale=0.1),
        ObservationNorm(standard_normal=True, in_keys=["pixels"]),
    ),
)

像往常一样,我们需要手动初始化归一化常数

env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])
td = env.reset()

策略 (Policy)#

我们的策略将包含 3 个组件:ConvNet 主干、LSTMModule 记忆层以及一个浅层的 MLP 块,它将把 LSTM 输出映射到动作值。

卷积网络#

我们构建了一个侧接 torch.nn.AdaptiveAvgPool2d 的卷积网络,它将输出压缩为一个大小为 64 的向量。ConvNet 可以协助我们完成此操作。

feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,
        aggregator_kwargs={"output_size": (1, 1)},
        device=device,
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)

我们在数据批次上执行第一个模块,以获取输出向量的大小

n_cells = feature(env.reset())["embed"].shape[-1]

LSTM 模块#

TorchRL 提供了一个专门的 LSTMModule 类来将 LSTM 合并到您的代码库中。它是 TensorDictModuleBase 的子类:因此,它有一组 in_keysout_keys,用于指示在模块执行期间应读取和写入/更新哪些值。该类带有可自定义的预定义值,以方便其构造。

注意

使用限制:该类支持几乎所有 LSTM 功能,如 dropout 或多层 LSTM。然而,为了遵守 TorchRL 的约定,此 LSTM 必须将 batch_first 属性设置为 True,这在 PyTorch 中不是默认设置。不过,我们的 LSTMModule 改变了此默认行为,因此我们可以直接进行原生调用。

此外,LSTM 不能将 bidirectional 属性设置为 True,因为这在在线设置中无法使用。在这种情况下,默认值是正确的。

lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="embed",
)

让我们看一下 LSTM 模块类,特别是它的 in_keys 和 out_keys

print("in_keys", lstm.in_keys)
print("out_keys", lstm.out_keys)

我们可以看到这些值包含我们指定为 in_key(和 out_key)的键以及循环键名称。out_keys 前面带有“next”前缀,表示它们需要写入“next” TensorDict 中。我们使用此约定(可以通过传递 in_keys/out_keys 参数来覆盖)来确保调用 step_mdp() 将把循环状态移动到根 TensorDict,使其在下一次调用时可供 RNN 使用(参见引言中的图)。

如前所述,我们还可以向环境添加一个可选转换,以确保循环状态传递给缓冲区。make_tensordict_primer() 方法正是这样做的。

env.append_transform(lstm.make_tensordict_primer())

就是这样!在添加了 primer 之后,我们可以打印环境来检查一切看起来是否正常。

print(env)

MLP#

我们使用单层 MLP 来表示我们将用于策略的动作值。

mlp = MLP(
    out_features=2,
    num_cells=[
        64,
    ],
    device=device,
)

并将偏置填充为零

mlp[-1].bias.data.fill_(0.0)
mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"])

使用 Q 值选择动作#

我们策略的最后一部分是 Q 值模块。QValueModule 将读取由我们的 MLP 产生的 "action_values" 键,并从中选出最大值的动作。我们需要做的唯一一件事就是指定动作空间,这可以通过传递字符串或动作规范(action-spec)来完成。这允许我们使用类别(Categorical,有时称为“稀疏”)编码或其 one-hot 版本。

qval = QValueModule(spec=env.action_spec)

注意

TorchRL 还提供了一个包装类 torchrl.modules.QValueActor,它将模块封装在一个 Sequential 中,并结合了一个像我们这里显式操作的 QValueModule。这样做的好处很少,而且过程透明度较低,但最终结果将与我们这里所做的类似。

现在我们可以将这些内容整合到 TensorDictSequential

stoch_policy = Seq(feature, lstm, mlp, qval)

DQN 是一种确定性算法,探索是其中的关键部分。我们将使用 \(\epsilon\)-greedy 策略,epsilon 为 0.2,并逐渐衰减至 0。这种衰减是通过调用 step() 实现的(见下方的训练循环)。

exploration_module = EGreedyModule(
    annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)
stoch_policy = Seq(
    stoch_policy,
    exploration_module,
)

将模型用于损失计算#

我们构建的模型非常适合在顺序设置中使用。但是,torch.nn.LSTM 类可以使用 cuDNN 优化的后端在 GPU 设备上更快地运行 RNN 序列。我们不想错过这样加速训练循环的机会!要使用它,我们只需告诉 LSTM 模块在被损失计算使用时运行在“recurrent-mode”下。因为我们通常想要有两个 LSTM 模块的副本,所以我们通过调用 set_recurrent_mode() 方法来实现,该方法将返回一个新的 LSTM 实例(具有共享权重),它将假定输入数据本质上是顺序的。

policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)

因为我们还有几个未初始化的参数,所以我们应该在创建优化器等之前初始化它们。

policy(env.reset())

DQN 损失#

DQN 损失需要我们传入策略以及动作空间。虽然这看起来多余,但很重要,因为我们希望确保 DQNLossQValueModule 类是兼容的,但又不会过度依赖彼此。

为了使用 Double-DQN,我们需要一个 delay_value 参数,它将创建一个非微分的网络参数副本,用作目标网络。

loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)

由于我们使用的是双 DQN,我们需要更新目标参数。我们将使用 SoftUpdate 实例来执行这项工作。

updater = SoftUpdate(loss_fn, eps=0.95)

optim = torch.optim.Adam(policy.parameters(), lr=3e-4)

收集器和重放缓冲区#

我们构建了最简单的数据收集器。我们将尝试用一百万帧来训练我们的算法,每次扩展缓冲区 50 帧。该缓冲区旨在存储 2 万条每条 50 步的轨迹。在每个优化步骤(每次数据收集进行 16 次)中,我们将从缓冲区中收集 4 个项目,总共 200 个转换。我们将使用 LazyMemmapStorage 存储来将数据保存在磁盘上。

注意

为了效率起见,我们在这里只运行了几千次迭代。在实际设置中,总帧数应设置为 1M。

collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
)

训练循环#

为了跟踪进度,我们将每 50 次数据收集在环境中运行一次策略,并在训练后绘制结果。

utd = 16
pbar = tqdm.tqdm(total=1_000_000)
longest = 0

traj_lens = []
for i, data in enumerate(collector):
    if i == 0:
        print(
            "Let us print the first batch of data.\nPay attention to the key names "
            "which will reflect what can be found in this data structure, in particular: "
            "the output of the QValueModule (action_values, action and chosen_action_value),"
            "the 'is_init' key that will tell us if a step is initial or not, and the "
            "recurrent_state keys.\n",
            data,
        )
    pbar.update(data.numel())
    # it is important to pass data that is not flattened
    rb.extend(data.unsqueeze(0).to_tensordict().cpu())
    for _ in range(utd):
        s = rb.sample().to(device, non_blocking=True)
        loss_vals = loss_fn(s)
        loss_vals["loss"].backward()
        optim.step()
        optim.zero_grad()
    longest = max(longest, data["step_count"].max().item())
    pbar.set_description(
        f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}"
    )
    exploration_module.step(data.numel())
    updater.step()

    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        rollout = env.rollout(10000, stoch_policy)
        traj_lens.append(rollout.get(("next", "step_count")).max().item())

让我们绘制结果

if traj_lens:
    from matplotlib import pyplot as plt

    plt.plot(traj_lens)
    plt.xlabel("Test collection")
    plt.title("Test trajectory lengths")

结论#

我们已经了解了如何将 RNN 合并到 TorchRL 的策略中。现在您应该能够:

  • 创建一个作为 TensorDictModule 的 LSTM 模块

  • 通过 InitTracker 转换向 LSTM 模块指示何时需要重置

  • 将此模块合并到策略和损失模块中

  • 确保收集器意识到循环状态条目,以便它们可以与其余数据一起存储在重放缓冲区中

进一步阅读#

  • TorchRL 文档可以在这里找到。