注意
转到底部 下载完整的示例代码。
使用预训练模型¶
本教程将解释如何在 TorchRL 中使用预训练模型。
import tempfile
在本教程结束时,您将能够使用预训练模型进行高效的图像表示,并对其进行微调。
TorchRL 提供预训练模型,这些模型可用作变换(transforms)或策略(policy)的组件。由于语义相同,它们可以互换地用于其中一种或另一种上下文。在本教程中,我们将使用 R3M(https://arxiv.org/abs/2203.12601),但其他模型(例如 VIP)同样有效。
import torch.cuda
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.envs import Compose, R3MTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import Actor
is_fork = multiprocessing.get_start_method() == "fork"
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
让我们先创建一个环境。为了简单起见,我们将使用一个通用的 gym 环境。实际上,这在更具挑战性的具身 AI 环境中也能正常工作(例如,看看我们的 Habitat 包装器)。
base_env = GymEnv("Ant-v4", from_pixels=True, device=device)
让我们获取预训练模型。我们通过 `download=True` 标志来请求模型的预训练版本。默认情况下,此选项是关闭的。接下来,我们将变换添加到环境中。实际上,发生的情况是,收集到的每个数据批次都将通过该变换,并在输出 tensordict 的“r3m_vec”条目中映射。然后,我们的策略(由单个 MLP 层组成)将读取此向量并计算相应的动作。
r3m = R3MTransform(
"resnet50",
in_keys=["pixels"],
download=False, # Turn to true for real-life testing
)
env_transformed = TransformedEnv(base_env, r3m)
net = nn.Sequential(
nn.LazyLinear(128, device=device),
nn.Tanh(),
nn.Linear(128, base_env.action_spec.shape[-1], device=device),
)
policy = Actor(net, in_keys=["r3m_vec"])
让我们检查策略的参数数量
print("number of params:", len(list(policy.parameters())))
我们收集 32 步的 rollout 并打印其输出
rollout = env_transformed.rollout(32, policy)
print("rollout with transform:", rollout)
对于微调,我们将变换集成到策略中,并将参数设置为可训练。实际上,限制在参数的子集(例如 MLP 的最后一层)上可能会更明智。
r3m.train()
policy = TensorDictSequential(r3m, policy)
print("number of params after r3m is integrated:", len(list(policy.parameters())))
再次,我们使用 R3M 收集一个 rollout。输出的结构略有变化,因为现在环境返回的是像素(而不是嵌入)。嵌入“r3m_vec”是策略的中间结果。
rollout = base_env.rollout(32, policy)
print("rollout, fine tuning:", rollout)
我们将变换从环境切换到策略的轻松程度,得益于两者都表现得像 `TensorDictModule`:它们都有一个 `“in_keys”` 和 `“out_keys”` 集合,使得在不同上下文中轻松读写输出成为可能。
为了结束本教程,让我们看看如何使用 R3M 读取存储在回放缓冲区中的图像(例如,在离线 RL 上下文中)。首先,让我们构建我们的数据集
from torchrl.data import LazyMemmapStorage, ReplayBuffer
buffer_scratch_dir = tempfile.TemporaryDirectory().name
storage = LazyMemmapStorage(1000, scratch_dir=buffer_scratch_dir)
rb = ReplayBuffer(storage=storage, transform=Compose(lambda td: td.to(device), r3m))
我们现在可以收集数据(为了我们的目的,是随机 rollout)并用它来填充回放缓冲区
total = 0
while total < 1000:
tensordict = base_env.rollout(1000)
rb.extend(tensordict)
total += tensordict.numel()
让我们检查一下我们的回放缓冲区存储是什么样的。它不应该包含“r3m_vec”条目,因为我们还没有使用它
print("stored data:", storage._storage)
在采样时,数据将通过 R3M 变换,为我们提供所需的处理后的数据。通过这种方式,我们可以使用由图像组成的数据集对算法进行离线训练
batch = rb.sample(32)
print("data after sampling:", batch)