Introduction to TorchRL

This demo was presented at ICML 2022 on the industry demo day.

It gives a good overview of TorchRL's functionality. Feel free to reach out to vmoens@fb.com or submit an issue if you have any questions or comments about it.

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

https://github.com/pytorch/rl

The PyTorch ecosystem team (Meta) has decided to invest in that library to provide a leading platform to develop RL solutions in research settings.

It provides PyTorch and **python-first**, low- and high-level **abstractions** for RL that are intended to be efficient, documented and properly tested. The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, such that researchers can easily swap components, transform them, or write new ones with little effort.

The repo attempts to align with the existing PyTorch ecosystem libraries, in that it has a dataset pillar (torchrl/envs), transforms, models, data utilities (e.g. collectors and containers), etc. TorchRL aims at having as few dependencies as possible (Python standard library, NumPy and PyTorch). Common environment libraries (e.g. OpenAI Gym) are only optional.

Contents:

[Diagram: overview of the TorchRL components]

Unlike other domains, RL is less about media than *algorithms*. As such, it is harder to make truly independent components.

What TorchRL is not

  • A collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms; we provide these algorithms only as examples of how to use the library.

  • A research framework: modularity in TorchRL comes in two flavours. First, we try to build re-usable components, such that they can be easily swapped with each other. Second, we do our best to ensure that components can be used independently of the rest of the library.

TorchRL has very few core dependencies, predominantly PyTorch and NumPy. All other dependencies (Gym, torchvision, wandb / tensorboard) are optional.

Data

TensorDict

import torch
from tensordict import TensorDict

Let's create a TensorDict. The constructor accepts many different formats, such as passing a dict or using keyword arguments:

batch_size = 5
data = TensorDict(
    key1=torch.zeros(batch_size, 3),
    key2=torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    batch_size=[batch_size],
)
print(data)
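
For instance, passing a dictionary as the first argument yields the same result:

data = TensorDict(
    {
        "key1": torch.zeros(batch_size, 3),
        "key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
print(data)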

You can index a TensorDict along its batch_size, as well as query its keys.

print(data[2])
print(data["key1"] is data.get("key1"))

The following shows how to stack multiple TensorDicts. This is particularly useful when writing rollout loops!

data1 = TensorDict(
    {
        "key1": torch.zeros(batch_size, 1),
        "key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

data2 = TensorDict(
    {
        "key1": torch.ones(batch_size, 1),
        "key2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

data = torch.stack([data1, data2], 0)
data.batch_size, data["key1"]

Here are some other functionalities of TensorDict: viewing, permuting, sharing memory, or expanding.

print(
    "view(-1): ",
    data.view(-1).batch_size,
    data.view(-1).get("key1").shape,
)

print("to device: ", data.to("cpu"))

# print("pin_memory: ", data.pin_memory())

print("share memory: ", data.share_memory_())

print(
    "permute(1, 0): ",
    data.permute(1, 0).batch_size,
    data.permute(1, 0).get("key1").shape,
)

print(
    "expand: ",
    data.expand(3, *data.batch_size).batch_size,
    data.expand(3, *data.batch_size).get("key1").shape,
)

You can also create **nested data**.

data = TensorDict(
    source={
        "key1": torch.zeros(batch_size, 3),
        "key2": TensorDict(
            source={"sub_key1": torch.zeros(batch_size, 2, 1)},
            batch_size=[batch_size, 2],
        ),
    },
    batch_size=[batch_size],
)
data
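
Nested values can be queried (or set) with a tuple of keys:

print(data["key2", "sub_key1"].shape)  # tuple keys index into nested TensorDicts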

Replay buffers

Replay buffers are a crucial component of many RL algorithms. TorchRL provides a range of replay buffer implementations. Most basic features will work with any data structure (list, tuple, dict), but to use the replay buffers to their full extent and with fast read and write access, the TensorDict API should be preferred.

from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

rb = ReplayBuffer(collate_fn=lambda x: x)

Adding can be done via add() (n=1) or extend() (n>1).

rb.add(1)
rb.sample(1)
rb.extend([2, 3])
rb.sample(3)

Prioritized replay buffers can also be used:

rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
rb.update_priority(1, 0.5)

Here is an example of a replay buffer used with a stack of data. Using these makes it easy to abstract away the replay buffer's behaviour across multiple use cases.

collate_fn = torch.stack
rb = ReplayBuffer(collate_fn=collate_fn)
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)

rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(len(rb))
print(rb.sample(10))
print(rb.sample(2).contiguous())
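
For fast, contiguous storage of tensordicts, the buffer can be backed by a tensor-based storage. A minimal sketch, assuming torchrl's LazyTensorStorage (the storage is allocated lazily on the first write):

from torchrl.data import LazyTensorStorage

rb = ReplayBuffer(storage=LazyTensorStorage(100))
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(rb.sample(2))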

torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer

rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
data_sample = rb.sample(2).contiguous()
print(data_sample)

print(data_sample["index"])

data_sample["td_error"] = torch.rand(2)
rb.update_tensordict_priority(data_sample)

# inspect the internal sum-tree that stores the sampling priorities
for i, val in enumerate(rb._sampler._sum_tree):
    print(i, val)
    if i == len(rb):
        break

Environments

TorchRL provides a range of environment wrappers and utilities.

Gym environments

try:
    import gymnasium as gym
except ModuleNotFoundError:
    import gym

from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend

gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
env = GymEnv("Pendulum-v1")

data = env.reset()
env.rand_step(data)
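
A quick sanity check, as a sketch: calling rollout() with no policy runs random actions and returns a stacked tensordict covering the whole trajectory:

print(env.rollout(max_steps=3))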

Changing environment configurations:

env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()

env.close()
del env

from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    ToTensorImage,
    TransformedEnv,
)

base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

Environment transforms

Transforms work like Gym wrappers, but with an API closer to torchvision's torch.distributions transforms. There is a wide range of transforms to choose from.

from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

env.reset()

print("env: ", env)
print("last transform parent: ", env.transform[2].parent)

Vectorized environments

Vectorized / parallel environments can provide significant speed-ups.

from torchrl.envs import ParallelEnv


def make_env():
    # You can control whether to use gym or gymnasium for your env
    with set_gym_backend("gym"):
        return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)


base_env = ParallelEnv(
    4,
    make_env,
    mp_start_method="fork",  # this will break on Windows! Remove it and wrap the script in if __name__ == "__main__" instead
)
env = TransformedEnv(
    base_env, Compose(StepCounter(), ToTensorImage())
)  # applies transforms on batch of envs
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()

print(env.action_spec)

env.close()
del env

Modules

Multiple modules (utilities, models and wrappers) can be found in the library.

Models

Example of an MLP model:

from torch import nn
from torchrl.modules import ConvNet, MLP
from torchrl.modules.models.utils import SquashDims

net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)
print(net(torch.randn(10, 3)).shape)

Example of a CNN model:

cnn = ConvNet(
    num_cells=[32, 64],
    kernel_sizes=[8, 4],
    strides=[2, 1],
    aggregator_class=SquashDims,
)
print(cnn)
print(cnn(torch.randn(10, 3, 32, 32)).shape)  # last tensor is squashed

TensorDict modules

Some modules are specifically designed to work with tensordict inputs.

from tensordict.nn import TensorDictModule

data = TensorDict({"key1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
td_module = TensorDictModule(module, in_keys=["key1"], out_keys=["key2"])
td_module(data)
print(data)

Sequences of modules

Making sequences of modules is easy with TensorDictSequential:

from tensordict.nn import TensorDictSequential

backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(
    backbone_module, in_keys=["observation"], out_keys=["hidden"]
)
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
value_module = MLP(out_features=1, num_cells=[4, 5])
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])

sequence = TensorDictSequential(backbone, actor, value)
print(sequence)

print(sequence.in_keys, sequence.out_keys)

data = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
backbone(data)
actor(data)
value(data)

data = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
sequence(data)
print(data)
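
A related convenience, as a sketch (assuming TensorDictSequential.select_subsequence is available in your tensordict version): sub-graphs can be extracted by naming the desired output keys, keeping only the modules needed to compute them:

sub_sequence = sequence.select_subsequence(out_keys=["hidden"])
print(sub_sequence)  # only the backbone is needed to compute "hidden"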

Functional programming (Ensembling / Meta-RL)

Functional calls have never been easier. Extract the parameters with from_module(), and replace them with to_module():

from tensordict import from_module

params = from_module(sequence)
print("extracted params", params)

Functional call using a tensordict of parameters:

with params.to_module(sequence):
    data = sequence(data)

VMAP

Running multiple copies of similar architectures quickly is crucial for training models fast. vmap() is tailored to do just that:

from torch import vmap

params_expand = params.expand(4)


def exec_sequence(params, data):
    with params.to_module(sequence):
        return sequence(data)


tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, data)
print(tensordict_exp)

Specialized classes

TorchRL also provides some specialized modules that run checks on their output values.

torch.manual_seed(0)
from torchrl.data import Bounded
from torchrl.modules import SafeModule

spec = Bounded(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
    module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
)
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(data)["action"]

data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
module(data)["action"]  # safe=True projects the result within the set

The Actor class comes with a predefined output key ("action"):

from torchrl.modules import Actor

base_module = nn.Linear(5, 3)
actor = Actor(base_module, in_keys=["obs"])
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(data)  # action is the default value

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)

Working with probabilistic models is also made easy thanks to the tensordict.nn API:

from torchrl.modules import NormalParamExtractor, TanhNormal

td = TensorDict({"input": torch.randn(3, 5)}, [3])
net = nn.Sequential(
    nn.Linear(5, 4), NormalParamExtractor()
)  # splits the output in loc and scale
module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=False,
    ),
)
td_module(td)
print(td)
# returning the log-probability
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    ),
)
td_module(td)
print(td)

Controlling randomness and sampling strategies is achieved via the set_exploration_type context manager:

from torchrl.envs.utils import ExplorationType, set_exploration_type

td = TensorDict({"input": torch.randn(3, 5)}, [3])

torch.manual_seed(0)
with set_exploration_type(ExplorationType.RANDOM):
    td_module(td)
    print("random:", td["action"])

with set_exploration_type(ExplorationType.DETERMINISTIC):
    td_module(td)
    print("mode:", td["action"])

Using environments and modules

Let's see how environments and modules can be combined:

from torchrl.envs.utils import step_mdp

env = GymEnv("Pendulum-v1")

action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = SafeModule(
    actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
)

torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
data = env.reset()
data_stack = TensorDict(batch_size=[max_steps])
for i in range(max_steps):
    actor(data)
    data_stack[i] = env.step(data)
    if data["done"].any():
        break
    data = step_mdp(data)  # roughly equivalent to obs = next_obs

tensordicts_prealloc = data_stack.clone()
print("total steps:", i)
print(data_stack)
# equivalent
torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
data = env.reset()
data_stack = []
for i in range(max_steps):
    actor(data)
    data_stack.append(env.step(data))
    if data["done"].any():
        break
    data = step_mdp(data)  # roughly equivalent to obs = next_obs
tensordicts_stack = torch.stack(data_stack, 0)
print("total steps:", i)
print(tensordicts_stack)
(tensordicts_stack == tensordicts_prealloc).all()
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout


(tensordict_rollout == tensordicts_prealloc).all()

from tensordict.nn import TensorDictModule

Collectors

We also provide a set of data collectors that automatically gather as many frames per batch as required. They work from single-node, single-worker to multi-node, multi-worker settings.

from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector

from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.libs.gym import GymEnv

EnvCreator makes sure that we can send a lambda function from one process to another. We use a SerialEnv for simplicity (single worker), but for larger jobs a ParallelEnv (multi-worker) would be better suited.

Note

Multiprocessed environments and multiprocessed collectors can be combined, as sketched right after this note!
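
A minimal sketch of that combination, running a ParallelEnv inside a multiprocessed collector (the fork/Windows caveat from the vectorized-environments section applies here too; the frame counts are arbitrary):

mp_env = ParallelEnv(2, EnvCreator(lambda: GymEnv("Pendulum-v1")))
mp_collector = MultiSyncDataCollector(
    [mp_env, mp_env],  # two sub-collectors, each stepping 2 envs in parallel
    policy=None,  # with policy=None, a random policy is used
    total_frames=80,
    frames_per_batch=40,
)
for d in mp_collector:
    print(d)
    break
mp_collector.shutdown()
del mp_collector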

parallel_env = SerialEnv(
    3,
    EnvCreator(lambda: GymEnv("Pendulum-v1")),
)
create_env_fn = [parallel_env, parallel_env]

actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])

Sync multiprocessed data collector

devices = ["cpu", "cpu"]

collector = MultiSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector

Async multiprocessed data collector

This class allows you to collect data while the model is training. This is particularly useful in off-policy settings, as it decouples inference from model training. Data is delivered on a first-come, first-served basis (workers will queue their results):

collector = MultiaSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)

for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
del create_env_fn
del parallel_env

Objectives

Objectives are the main entry point when coding up a new algorithm.

from torchrl.objectives import DDPGLoss

actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])


class ConcatModule(nn.Linear):
    def forward(self, obs, action):
        return super().forward(torch.cat([obs, action], -1))


value_module = ConcatModule(4, 1)
value = TensorDictModule(
    value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
)

loss_fn = DDPGLoss(actor, value)
loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)
data = TensorDict(
    {
        "observation": torch.randn(10, 3),
        "next": {
            "observation": torch.randn(10, 3),
            "reward": torch.randn(10, 1),
            "done": torch.zeros(10, 1, dtype=torch.bool),
        },
        "action": torch.randn(10, 1),
    },
    batch_size=[10],
    device="cpu",
)
loss_td = loss_fn(data)

print(loss_td)

print(data)
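
The loss tensordict contains one entry per loss component (plus diagnostics). A typical training step sums the "loss_*" entries and backpropagates; a minimal sketch, assuming an Adam optimizer:

from torch.optim import Adam

optim = Adam(loss_fn.parameters(), lr=2e-4)
loss = sum(loss_td[key] for key in loss_td.keys() if key.startswith("loss_"))
loss.backward()
optim.step()
optim.zero_grad()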

Installing the library

The library is on PyPI: pip install torchrl. See the README for more information.

Contributing

We are actively looking for contributors and early users. If you're working in RL (or just curious), try it out! Give us feedback: what will make TorchRL succeed is how well it covers researchers' needs. To do that, we need their input! Since the library is nascent, it is a great time for you to shape it the way you want!

See the contributing guide for more information.
