• 文档 >
  • 导出 TorchRL 模块
快捷方式

导出 TorchRL 模块

作者Vincent Moens

注意

要在 notebook 中运行本教程,请在开头添加一个安装单元格,其中包含:

!pip install tensordict
!pip install torchrl
!pip install "gymnasium[atari]"

介绍

学习到的策略如果无法在真实环境中部署,其价值就微乎其微。正如其他教程所示,TorchRL 极其注重模块化和可组合性:得益于 tensordict,该库的组件可以以最通用的方式编写,通过将它们的签名抽象为对输入 TensorDict 的一系列操作。这可能会给人一种错觉,认为该库仅限于训练使用,因为典型的底层执行硬件(边缘设备、机器人、Arduino、Raspberry Pi)不执行 Python 代码,更不用说安装了 pytorch、tensordict 或 torchrl。

幸运的是,PyTorch 提供了一整套解决方案,用于将代码和训练好的模型导出到设备和硬件,TorchRL 完全具备与这些解决方案交互的能力。可以选择各种后端,包括本教程中示例化的 ONNX 或 AOTInductor。本教程简要概述了如何将训练好的模型隔离并打包成独立的、可执行的文件,以便导出到硬件上。

主要学习内容

  • 在训练后导出任何 TorchRL 模块;

  • 使用各种后端;

  • 测试导出的模型。

快速回顾:一个简单的 TorchRL 训练循环

在本节中,我们将重现上一篇入门教程中的训练循环,并稍作调整,以便与 gymnasium 库渲染的 Atari 游戏一起使用。我们将继续使用 DQN 示例,并展示稍后如何使用输出值分布的策略。

import time
from pathlib import Path

import numpy as np

import torch

from tensordict.nn import (
    TensorDictModule as Mod,
    TensorDictSequential,
    TensorDictSequential as Seq,
)

from torch.optim import Adam

from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

from torchrl.envs import (
    Compose,
    GrayScale,
    GymEnv,
    Resize,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

from torchrl.modules import ConvNet, EGreedyModule, QValueModule

from torchrl.objectives import DQNLoss, SoftUpdate

torch.manual_seed(0)

env = TransformedEnv(
    GymEnv("ALE/Pong-v5", categorical_action_encoding=True),
    Compose(
        ToTensorImage(), Resize(84, interpolation="nearest"), GrayScale(), StepCounter()
    ),
)
env.set_seed(0)

value_mlp = ConvNet.default_atari_dqn(num_actions=env.action_spec.space.n)
value_net = Mod(value_mlp, in_keys=["pixels"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters())
updater = SoftUpdate(loss, eps=0.99)

total_count = 0
total_episodes = 0
t0 = time.time()
for data in collector:
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

导出基于 TensorDictModule 的策略

TensorDict 使我们能够构建一个具有高度灵活性的策略:从一个常规的 Module(该模块从观察值输出动作值),我们添加了一个 QValueModule 模块,该模块读取这些值并使用某种启发式方法(例如,argmax 调用)计算动作。

然而,在我们的例子中有一个小技术难点:环境(实际的 Atari 游戏)返回的不是灰度、84x84 的图像,而是原始的全屏彩色图像。我们附加到环境中的转换确保模型可以读取图像。我们可以看到,从训练的角度来看,环境和模型之间的界限是模糊的,但在执行时,情况就清晰多了:模型应该负责将输入数据(图像)转换为我们的 CNN 可以处理的格式。

在这里,tensordict 的魔力将再次帮助我们:事实证明,TorchRL 的大多数本地(非递归)转换都可以用作环境转换或 Module 实例内的预处理块。让我们看看如何将它们前置到我们的策略中。

policy_transform = TensorDictSequential(
    env.transform[
        :-1
    ],  # the last transform is a step counter which we don't need for preproc
    policy_explore.requires_grad_(
        False
    ),  # Using the explorative version of the policy for didactic purposes, see below.
)

我们创建一个假的输入,并将其传递给 export() 和策略。这将产生一个“原始”的 Python 函数,该函数将读取我们的输入张量并输出一个动作,而无需引用 TorchRL 或 tensordict 模块。

一个好的做法是调用 select_out_keys(),让模型知道我们只需要一组特定的输出(以防策略返回多个张量)。

fake_td = env.base_env.fake_tensordict()
pixels = fake_td["pixels"]
with set_exploration_type("DETERMINISTIC"):
    exported_policy = torch.export.export(
        # Select only the "action" output key
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

表示策略可能很有启发性:我们可以看到第一个操作是 permute、div、unsqueeze、resize,然后是卷积层和 MLP 层。

print("Deterministic policy")
exported_policy.graph_module.print_readable()

作为最后的检查,我们可以使用一个 dummy 输入来执行策略。输出(对于单个图像)应该是 0 到 6 之间的整数,代表要在游戏中执行的动作。

output = exported_policy.module()(pixels=pixels)
print("Exported module output", output)

有关导出 TensorDictModule 实例的更多详细信息,请参阅 tensordict 的 文档

注意

导出接受和输出嵌套键的模块是完全可以的。相应的 kwargs 将是键的 “_”.join(key) 版本,即 (“group0”, “agent0”, “obs”) 键将对应于 “group0_agent0_obs” 关键字参数。冲突的键(例如,(“group0_agent0”, “obs”)(“group0”, “agent0_obs”))可能会导致未定义行为,应不惜一切代价避免。显然,键名也应始终生成有效的关键字参数,即它们不应包含特殊字符,如空格或逗号。

torch.export 还有许多其他功能,我们将在下面进一步探讨。在此之前,让我们对测试时推理、以及递归策略的探索和随机策略做一个小小的离题。

处理随机策略

您可能已经注意到,上面我们使用了 set_exploration_type 上下文管理器来控制策略的行为。如果策略是随机的(例如,策略输出动作空间的分布,就像 PPO 或其他类似的在线策略算法一样)或具有探索性(带有附加的探索模块,如 E-Greedy、加性高斯或 Ornstein-Uhlenbeck),我们可能希望或不希望在其导出的版本中使用该探索策略。幸运的是,导出工具可以理解该上下文管理器,只要导出发生在正确的上下文管理器内,策略的行为就应该与指示的一致。为了证明这一点,让我们尝试另一种探索类型。

with set_exploration_type("RANDOM"):
    exported_stochastic_policy = torch.export.export(
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

我们的导出策略现在应该在调用堆栈的末尾有一个随机模块,这与之前的版本不同。事实上,最后三个操作是:生成一个 0 到 6 之间的随机整数,使用一个随机掩码,并根据掩码中的值选择网络输出或随机动作。

print("Stochastic policy")
exported_stochastic_policy.graph_module.print_readable()

处理递归策略

另一个典型用例是递归策略,它将输出一个动作以及一个或多个递归状态。LSTM 和 GRU 是基于 CuDNN 的模块,这意味着它们的行为与常规 Module 实例不同(导出工具可能无法很好地跟踪它们)。幸运的是,TorchRL 提供了这些模块的 Python 实现,可以在需要时替换 CuDNN 版本。

为了展示这一点,让我们编写一个依赖于 RNN 的原型策略。

from tensordict.nn import TensorDictModule
from torchrl.envs import BatchSizeTransform
from torchrl.modules import LSTMModule, MLP

lstm = LSTMModule(
    input_size=32,
    num_layers=2,
    hidden_size=256,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", "hidden0", "hidden1"],
)

如果 LSTM 模块不是基于 Python 而是基于 CuDNN(LSTM),则可以使用 make_python_based() 方法来使用 Python 版本。

lstm = lstm.make_python_based()

现在让我们创建策略。我们将两个修改输入形状的层(unsqueeze/squeeze 操作)与 LSTM 和 MLP 结合起来。

recurrent_policy = TensorDictSequential(
    # Unsqueeze the first dim of all tensors to make LSTMCell happy
    BatchSizeTransform(reshape_fn=lambda x: x.unsqueeze(0)),
    lstm,
    TensorDictModule(
        MLP(in_features=256, out_features=5, num_cells=[64, 64]),
        in_keys=["intermediate"],
        out_keys=["action"],
    ),
    # Squeeze the first dim of all tensors to get the original shape back
    BatchSizeTransform(reshape_fn=lambda x: x.squeeze(0)),
)

与之前一样,我们选择相关的键。

recurrent_policy.select_out_keys("action", "hidden0", "hidden1")
print("recurrent policy input keys:", recurrent_policy.in_keys)
print("recurrent policy output keys:", recurrent_policy.out_keys)

我们现在准备导出。为此,我们构建假的输入并将它们传递给 export()

fake_obs = torch.randn(32)
fake_hidden0 = torch.randn(2, 256)
fake_hidden1 = torch.randn(2, 256)

# Tensor indicating whether the state is the first of a sequence
fake_is_init = torch.zeros((), dtype=torch.bool)

exported_recurrent_policy = torch.export.export(
    recurrent_policy,
    args=(),
    kwargs={
        "observation": fake_obs,
        "hidden0": fake_hidden0,
        "hidden1": fake_hidden1,
        "is_init": fake_is_init,
    },
    strict=False,
)
print("Recurrent policy graph:")
exported_recurrent_policy.graph_module.print_readable()

AOTInductor:将您的策略导出为不依赖 PyTorch 的 C++ 二进制文件

AOTInductor 是一个 PyTorch 模块,允许您将模型(策略或其他)导出为不依赖 PyTorch 的 C++ 二进制文件。当您需要在 PyTorch 不可用的设备或平台上部署模型时,这尤其有用。

以下是如何使用 AOTInductor 导出策略的示例,灵感来自 AOTI 文档

from tempfile import TemporaryDirectory

from torch._inductor import aoti_compile_and_package, aoti_load_package

with TemporaryDirectory() as tmpdir:
    path = str(Path(tmpdir) / "model.pt2")
    with torch.no_grad():
        pkg_path = aoti_compile_and_package(
            exported_policy,
            # Specify the generated shared library path
            package_path=path,
        )
    print("pkg_path", pkg_path)

    compiled_module = aoti_load_package(pkg_path)

print(compiled_module(pixels=pixels))

使用 ONNX 导出 TorchRL 模型

注意

要运行此脚本部分,请确保已安装 pytorch onnx。

!pip install onnx-pytorch
!pip install onnxruntime

您还可以通过 此处 找到更多关于在 PyTorch 生态系统中使用 ONNX 的信息。以下示例基于此文档。

在本节中,我们将展示如何以一种可以在不依赖 PyTorch 的情况下执行模型的方式导出模型。

网上有很多资源解释了 ONNX 如何用于在各种硬件和设备上部署 PyTorch 模型,包括 Raspberry PiNVIDIA TensorRTiOSAndroid

我们训练的 Atari 游戏可以使用 ALE 库 在没有 TorchRL 或 gymnasium 的情况下进行隔离,因此为我们提供了可以使用 ONNX 实现的良好示例。

让我们看看这个 API 是什么样的。

from ale_py import ALEInterface, roms

# Create the interface
ale = ALEInterface()
# Load the pong environment
ale.loadROM(roms.Pong)
ale.reset_game()

# Make a step in the simulator
action = 0
reward = ale.act(action)
screen_obs = ale.getScreenRGB()
print("Observation from ALE simulator:", type(screen_obs), screen_obs.shape)

from matplotlib import pyplot as plt

plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
plt.imshow(screen_obs)
plt.title("Screen rendering of Pong game.")

导出到 ONNX 与上面的 Export/AOTI 非常相似。

import onnxruntime

with set_exploration_type("DETERMINISTIC"):
    # We use torch.onnx.dynamo_export to capture the computation graph from our policy_explore model
    pixels = torch.as_tensor(screen_obs)
    onnx_policy_export = torch.onnx.dynamo_export(policy_transform, pixels=pixels)

我们现在可以将程序保存在磁盘上并加载它。

with TemporaryDirectory() as tmpdir:
    onnx_file_path = str(Path(tmpdir) / "policy.onnx")
    onnx_policy_export.save(onnx_file_path)

    ort_session = onnxruntime.InferenceSession(
        onnx_file_path, providers=["CPUExecutionProvider"]
    )

onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
onnx_policy = ort_session.run(None, onnxruntime_input)

使用 ONNX 运行 rollout

我们现在有了一个运行我们策略的 ONNX 模型。让我们将其与原始 TorchRL 实例进行比较:由于其轻量级,ONNX 版本应该比 TorchRL 版本更快。

def onnx_policy(screen_obs: np.ndarray) -> int:  # noqa: F811
    onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
    action = int(onnxruntime_outputs[0])
    return action


with timeit("ONNX rollout"):
    num_steps = 1000
    ale.reset_game()
    for _ in range(num_steps):
        screen_obs = ale.getScreenRGB()
        action = onnx_policy(screen_obs)
        reward = ale.act(action)

with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
    env.rollout(num_steps, policy_explore)

print(timeit.print())

请注意,ONNX 还提供了直接优化模型的可能性,但这超出了本教程的范围。

结论

在本教程中,我们学习了如何使用各种后端导出 TorchRL 模块,例如 PyTorch 的内置导出功能、AOTInductorONNX。我们演示了如何导出在 Atari 游戏上训练的模型,并使用 ALE 库在不依赖 PyTorch 的环境中运行它。我们还比较了原始 TorchRL 实例与导出的 ONNX 模型的性能。

关键要点

  • 导出 TorchRL 模块允许部署在未安装 PyTorch 的设备上。

  • AOTInductor 和 ONNX 提供了导出模型的替代后端。

  • 优化 ONNX 模型可以提高性能。

进一步阅读和学习步骤

  • 有关更多信息,请查看 PyTorch 的 导出功能AOTInductorONNX 的官方文档。

  • 尝试将导出的模型部署到不同的设备上。

  • 探索 ONNX 模型的优化技术以提高性能。

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源