• 文档 >
  • 导出 TorchRL 模块
快捷方式

导出 TorchRL 模块

作者Vincent Moens

注意

要在 notebook 中运行本教程,请在开头添加一个安装单元格,其中包含:

!pip install tensordict
!pip install torchrl
!pip install "gymnasium[atari]"

介绍

学习一个策略,如果该策略无法在实际环境中部署,其价值将很小。正如其他教程所示,TorchRL 非常注重模块化和可组合性:得益于 tensordict,库的组件可以以最通用的方式编写,通过将它们的签名抽象为对输入 TensorDict 的一系列操作。这可能给人一种该库只能用于训练的印象,因为典型的低级执行硬件(边缘设备、机器人、Arduino、Raspberry Pi)不执行 Python 代码,更不用说安装了 pytorch、tensordict 或 torchrl。

幸运的是,PyTorch 提供了一整套解决方案,用于将代码和训练好的模型导出到设备和硬件,TorchRL 完全能够与之交互。可以选择多种后端,包括本教程中示例的 ONNX 或 AOTInductor。本教程将快速概述如何隔离训练好的模型并将其作为独立的应用程序进行分发,以便导出到硬件。

主要学习内容

  • 在训练后导出任何 TorchRL 模块;

  • 使用各种后端;

  • 测试导出的模型。

快速回顾:一个简单的 TorchRL 训练循环

在本节中,我们将重现最后一个入门教程中的训练循环,并稍作调整以用于 gymnasium 库渲染的 Atari 游戏。我们将继续使用 DQN 示例,并展示稍后如何使用输出值分布的策略。

import time
from pathlib import Path

import numpy as np

import torch

from tensordict.nn import (
    TensorDictModule as Mod,
    TensorDictSequential,
    TensorDictSequential as Seq,
)

from torch.optim import Adam

from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

from torchrl.envs import (
    Compose,
    GrayScale,
    GymEnv,
    Resize,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

from torchrl.modules import ConvNet, EGreedyModule, QValueModule

from torchrl.objectives import DQNLoss, SoftUpdate

torch.manual_seed(0)

env = TransformedEnv(
    GymEnv("ALE/Pong-v5", categorical_action_encoding=True),
    Compose(
        ToTensorImage(), Resize(84, interpolation="nearest"), GrayScale(), StepCounter()
    ),
)
env.set_seed(0)

value_mlp = ConvNet.default_atari_dqn(num_actions=env.action_spec.space.n)
value_net = Mod(value_mlp, in_keys=["pixels"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters())
updater = SoftUpdate(loss, eps=0.99)

total_count = 0
total_episodes = 0
t0 = time.time()
for data in collector:
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

导出基于 TensorDictModule 的策略

TensorDict 使我们能够构建一个具有高度灵活性的策略:从一个 regular Module(该模块根据观察值输出动作值),我们添加了一个 QValueModule 模块,该模块读取这些值并使用某种启发式方法(例如,argmax 调用)计算动作。

但是,在这种情况下存在一个小技术问题:环境(实际的 Atari 游戏)返回的不是灰度、84x84 的图像,而是原始屏幕大小的彩色图像。我们附加到环境的变换确保模型可以读取这些图像。我们可以看到,从训练的角度来看,环境和模型之间的界限是模糊的,但在执行时,情况就清楚多了:模型应该负责将输入数据(图像)转换为我们的 CNN 可以处理的格式。

这里,tensordict 的魔力再次发挥作用:事实证明,TorchRL 的大多数本地(非递归)变换都可以用作环境变换或 Module 实例中的预处理块。让我们看看如何将它们预置到我们的策略中。

policy_transform = TensorDictSequential(
    env.transform[
        :-1
    ],  # the last transform is a step counter which we don't need for preproc
    policy_explore.requires_grad_(
        False
    ),  # Using the explorative version of the policy for didactic purposes, see below.
)

我们创建一个假的输入,并将其与策略一起传递给 export()。这将生成一个“原始” Python 函数,该函数将读取我们的输入张量并输出一个动作,而无需引用 TorchRL 或 tensordict 模块。

一个好的做法是调用 select_out_keys(),让模型知道我们只想要特定的一组输出(以防策略返回的张量不止一个)。

fake_td = env.base_env.fake_tensordict()
pixels = fake_td["pixels"]
with set_exploration_type("DETERMINISTIC"):
    exported_policy = torch.export.export(
        # Select only the "action" output key
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

表示策略可以相当有启发性:我们可以看到,第一个操作是 permute、div、unsqueeze、resize,然后是卷积层和 MLP 层。

print("Deterministic policy")
exported_policy.graph_module.print_readable()

作为最后的检查,我们可以使用一个虚拟输入来执行策略。输出(对于单张图像)应该是一个从 0 到 6 的整数,表示要在游戏中执行的动作。

output = exported_policy.module()(pixels=pixels)
print("Exported module output", output)

有关导出 TensorDictModule 实例的更多详细信息,请参阅 tensordict 的 文档

注意

导出接收和输出嵌套键的模块是完全可以的。相应的 kwargs 将是键的 “_”.join(key) 版本,即 (“group0”, “agent0”, “obs”) 键将对应于 “group0_agent0_obs” 关键字参数。冲突的键(例如 (“group0_agent0”, “obs”)(“group0”, “agent0_obs”))可能导致未定义的行为,应不惜一切代价避免。显然,键名也应始终产生有效的关键字参数,即它们不应包含特殊字符,如空格或逗号。

torch.export 还有许多其他功能,我们将在下文中进一步探讨。在此之前,让我们先简单地探讨一下测试时推理、探索以及循环策略的背景。

使用随机策略

您可能已经注意到,上面我们使用了 set_exploration_type 上下文管理器来控制策略的行为。如果策略是随机的(例如,策略输出动作空间上的分布,如 PPO 或其他类似的在线策略算法),或者具有探索性(附加了 E-Greedy、加性高斯或 Ornstein-Uhlenbeck 等探索模块),我们可能希望或不希望在其导出的版本中使用该探索策略。幸运的是,导出工具能够理解该上下文管理器,只要导出发生在正确的上下文管理器内,策略的行为就应该与指示的一致。为了演示这一点,让我们尝试另一种探索类型。

with set_exploration_type("RANDOM"):
    exported_stochastic_policy = torch.export.export(
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

我们的导出的策略在调用堆栈的末尾应该有一个随机模块,与之前的版本不同。实际上,最后三个操作是:生成一个介于 0 到 6 之间的随机整数,使用随机掩码,并根据掩码中的值选择网络输出或随机动作。

print("Stochastic policy")
exported_stochastic_policy.graph_module.print_readable()

使用循环策略

另一个典型的用例是循环策略,它将输出一个动作以及一个或多个循环状态。LSTM 和 GRU 是基于 CuDNN 的模块,这意味着它们的行为与 regular Module 实例不同(导出工具可能无法很好地跟踪它们)。幸运的是,TorchRL 提供了这些模块的 Python 实现,可以在需要时替换 CuDNN 版本。

为了展示这一点,让我们编写一个依赖于 RNN 的原型策略。

from tensordict.nn import TensorDictModule
from torchrl.envs import BatchSizeTransform
from torchrl.modules import LSTMModule, MLP

lstm = LSTMModule(
    input_size=32,
    num_layers=2,
    hidden_size=256,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", "hidden0", "hidden1"],
)

如果 LSTM 模块不是基于 Python 而是基于 CuDNN(LSTM),可以使用 torchrl.modules.LSTMModule.make_python_based() 方法来使用 Python 版本。

lstm = lstm.make_python_based()

现在让我们创建策略。我们将两个修改输入形状(unsqueeze/squeeze 操作)的层与 LSTM 和 MLP 结合起来。

recurrent_policy = TensorDictSequential(
    # Unsqueeze the first dim of all tensors to make LSTMCell happy
    BatchSizeTransform(reshape_fn=lambda x: x.unsqueeze(0)),
    lstm,
    TensorDictModule(
        MLP(in_features=256, out_features=5, num_cells=[64, 64]),
        in_keys=["intermediate"],
        out_keys=["action"],
    ),
    # Squeeze the first dim of all tensors to get the original shape back
    BatchSizeTransform(reshape_fn=lambda x: x.squeeze(0)),
)

与之前一样,我们选择相关的键。

recurrent_policy.select_out_keys("action", "hidden0", "hidden1")
print("recurrent policy input keys:", recurrent_policy.in_keys)
print("recurrent policy output keys:", recurrent_policy.out_keys)

我们现在准备导出。为此,我们构建假的输入并将它们传递给 export()

fake_obs = torch.randn(32)
fake_hidden0 = torch.randn(2, 256)
fake_hidden1 = torch.randn(2, 256)

# Tensor indicating whether the state is the first of a sequence
fake_is_init = torch.zeros((), dtype=torch.bool)

exported_recurrent_policy = torch.export.export(
    recurrent_policy,
    args=(),
    kwargs={
        "observation": fake_obs,
        "hidden0": fake_hidden0,
        "hidden1": fake_hidden1,
        "is_init": fake_is_init,
    },
    strict=False,
)
print("Recurrent policy graph:")
exported_recurrent_policy.graph_module.print_readable()

AOTInductor:将您的策略导出为不含 PyTorch 的 C++ 二进制文件

AOTInductor 是一个 PyTorch 模块,允许您将模型(策略或其他)导出为不含 PyTorch 的 C++ 二进制文件。当您需要在 PyTorch 不可用的设备或平台上部署模型时,这特别有用。

以下是如何使用 AOTInductor 导出策略的示例,该示例受 AOTI 文档 的启发。

from tempfile import TemporaryDirectory

from torch._inductor import aoti_compile_and_package, aoti_load_package

with TemporaryDirectory() as tmpdir:
    path = str(Path(tmpdir) / "model.pt2")
    with torch.no_grad():
        pkg_path = aoti_compile_and_package(
            exported_policy,
            # Specify the generated shared library path
            package_path=path,
        )
    print("pkg_path", pkg_path)

    compiled_module = aoti_load_package(pkg_path)

print(compiled_module(pixels=pixels))

使用 ONNX 导出 TorchRL 模型

注意

要运行此脚本部分,请确保已安装 pytorch onnx。

!pip install onnx-pytorch
!pip install onnxruntime

您也可以在此处找到有关在 PyTorch 生态系统中使用 ONNX 的更多信息。以下示例基于此文档。

在本节中,我们将展示如何以可以在不含 PyTorch 的环境中执行的方式导出我们的模型。

网络上有许多资源解释了如何使用 ONNX 将 PyTorch 模型部署到各种硬件和设备上,包括 Raspberry PiNVIDIA TensorRTiOSAndroid

我们训练过的 Atari 游戏可以使用 ALE 库 独立于 TorchRL 或 gymnasium,因此为我们提供了与 ONNX 相关的良好示例。

让我们看看这个 API 是什么样子。

from ale_py import ALEInterface, roms

# Create the interface
ale = ALEInterface()
# Load the pong environment
ale.loadROM(roms.Pong)
ale.reset_game()

# Make a step in the simulator
action = 0
reward = ale.act(action)
screen_obs = ale.getScreenRGB()
print("Observation from ALE simulator:", type(screen_obs), screen_obs.shape)

from matplotlib import pyplot as plt

plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
plt.imshow(screen_obs)
plt.title("Screen rendering of Pong game.")

导出到 ONNX 与上面的 Export/AOTI 非常相似。

import onnxruntime

with set_exploration_type("DETERMINISTIC"):
    # We use torch.onnx.dynamo_export to capture the computation graph from our policy_explore model
    pixels = torch.as_tensor(screen_obs)
    onnx_policy_export = torch.onnx.dynamo_export(policy_transform, pixels=pixels)

我们现在可以将程序保存在磁盘上并加载它。

with TemporaryDirectory() as tmpdir:
    onnx_file_path = str(Path(tmpdir) / "policy.onnx")
    onnx_policy_export.save(onnx_file_path)

    ort_session = onnxruntime.InferenceSession(
        onnx_file_path, providers=["CPUExecutionProvider"]
    )

onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
onnx_policy = ort_session.run(None, onnxruntime_input)

使用 ONNX 运行 Rollout

我们现在有了一个运行我们策略的 ONNX 模型。让我们将其与原始 TorchRL 实例进行比较:由于它更轻量级,ONNX 版本应该比 TorchRL 版本更快。

def onnx_policy(screen_obs: np.ndarray) -> int:
    onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
    action = int(onnxruntime_outputs[0])
    return action


with timeit("ONNX rollout"):
    num_steps = 1000
    ale.reset_game()
    for _ in range(num_steps):
        screen_obs = ale.getScreenRGB()
        action = onnx_policy(screen_obs)
        reward = ale.act(action)

with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
    env.rollout(num_steps, policy_explore)

print(timeit.print())

请注意,ONNX 还提供了直接优化模型的可能性,但这超出了本教程的范围。

结论

在本教程中,我们学习了如何使用 PyTorch 内置导出功能、AOTInductorONNX 等各种后端导出 TorchRL 模块。我们演示了如何导出在 Atari 游戏上训练的模型,并使用 ALE 库在不含 PyTorch 的环境中运行它。我们还比较了原始 TorchRL 实例和导出的 ONNX 模型的性能。

关键要点

  • 导出 TorchRL 模块允许部署到没有安装 PyTorch 的设备上。

  • AOTInductor 和 ONNX 提供了导出模型的替代后端。

  • 优化 ONNX 模型可以提高性能。

进一步阅读和学习步骤

  • 请查看 PyTorch 导出功能AOTInductorONNX 的官方文档以获取更多信息。

  • 尝试将导出的模型部署到不同的设备上。

  • 探索 ONNX 模型的优化技术以提高性能。

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源