• 文档 >
  • 使用预训练模型
快捷方式

使用预训练模型

本教程解释了如何在 TorchRL 中使用预训练模型。

import tempfile

在本教程结束时,您将能够使用预训练模型进行高效的图像表示,并对其进行微调。

TorchRL 提供了预训练模型,它们既可以作为变换,也可以作为策略的组件。由于语义相同,它们可以在这两种上下文中互换使用。在本教程中,我们将使用 R3M(https://arxiv.org/abs/2203.12601),但其他模型(例如 VIP)同样有效。

import torch.cuda
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.envs import Compose, R3MTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import Actor

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

让我们先创建一个环境。为了简单起见,我们将使用一个常见的 gym 环境。实际上,这将适用于更具挑战性的、具身 AI 的场景(例如,请看我们的 Habitat 包装器)。

base_env = GymEnv("Ant-v4", from_pixels=True, device=device)

让我们获取预训练模型。我们通过 download=True 标志请求模型的预训练版本。默认情况下,此选项处于关闭状态。接下来,我们将把我们的变换附加到环境中。实际上,将会发生的是,收集到的每个数据批次都将通过变换,并被映射到输出 tensordict 的“r3m_vec”条目中。我们的策略由一个单层 MLP 组成,然后它将读取此向量并计算相应的动作。

r3m = R3MTransform(
    "resnet50",
    in_keys=["pixels"],
    download=False,  # Turn to true for real-life testing
)
env_transformed = TransformedEnv(base_env, r3m)
net = nn.Sequential(
    nn.LazyLinear(128, device=device),
    nn.Tanh(),
    nn.Linear(128, base_env.action_spec.shape[-1], device=device),
)
policy = Actor(net, in_keys=["r3m_vec"])

让我们检查一下策略的参数数量

print("number of params:", len(list(policy.parameters())))

我们收集一个 32 步的 rollout 并打印其输出

rollout = env_transformed.rollout(32, policy)
print("rollout with transform:", rollout)

为了进行微调,我们在使参数可训练之后将变换集成到策略中。实际上,将此限制在部分参数(例如 MLP 的最后一层)可能更明智。

r3m.train()
policy = TensorDictSequential(r3m, policy)
print("number of params after r3m is integrated:", len(list(policy.parameters())))

再次,我们使用 R3M 收集一个 rollout。输出的结构略有变化,因为现在环境返回像素(而不是嵌入)。嵌入“r3m_vec”是我们策略的中间结果。

rollout = base_env.rollout(32, policy)
print("rollout, fine tuning:", rollout)

我们之所以能够轻松地将变换从环境切换到策略,是因为两者都表现得像 TensorDictModule:它们都有一个 “in_keys”“out_keys” 集合,使得在不同上下文中轻松读取和写入输出。

为了结束本教程,让我们看看如何使用 R3M 读取回放缓冲区中存储的图像(例如,在离线 RL 上下文中)。首先,让我们构建我们的数据集

from torchrl.data import LazyMemmapStorage, ReplayBuffer

buffer_scratch_dir = tempfile.TemporaryDirectory().name
storage = LazyMemmapStorage(1000, scratch_dir=buffer_scratch_dir)
rb = ReplayBuffer(storage=storage, transform=Compose(lambda td: td.to(device), r3m))

我们现在可以收集数据(为我们的目的进行随机 rollout)并用其填充回放缓冲区

total = 0
while total < 1000:
    tensordict = base_env.rollout(1000)
    rb.extend(tensordict)
    total += tensordict.numel()

让我们检查一下我们的回放缓冲区存储的样子。它不应该包含“r3m_vec”条目,因为我们还没有使用它

print("stored data:", storage._storage)

在采样时,数据将通过 R3M 变换,为我们提供所需的已处理数据。这样,我们就可以在由图像组成的数据集上离线训练算法

batch = rb.sample(32)
print("data after sampling:", batch)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源