Note

Go to the end to download the full example code.
Introduction to TorchRL
This demo was presented at ICML 2022 on the industry demo day.
It gives a good overview of TorchRL's functionalities. Feel free to reach out to vmoens@fb.com or submit an issue if you have questions or comments about it.
TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.
The PyTorch ecosystem team (Meta) has decided to invest in that library to provide a leading platform to develop RL solutions in research settings.
It provides PyTorch and **python-first**, low- and high-level **abstractions** for RL that are intended to be efficient, documented and properly tested. The code is aimed at supporting research in RL. Most of it is written in python in a highly modular way, such that researchers can easily swap components, transform them, or write new ones with little effort.
This repo attempts to align with the existing PyTorch ecosystem libraries in that it has a dataset pillar (torchrl/envs), transforms, models, data utilities (e.g. collectors and containers), etc. TorchRL aims at having as few dependencies as possible (python standard library, numpy and pytorch). Common environment libraries (e.g. OpenAI gym) are only optional.
Unlike many other domains, RL is less about media than *algorithms*. As such, it is harder to make truly independent components.
What TorchRL is not:

- a collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms; we provide these algorithms only as examples of how to use the library.
- a research framework: modularity in TorchRL comes in two flavors. First, we try to build reusable components, so that they can be easily swapped with each other. Second, we do our best to make components usable independently of the rest of the library.

TorchRL has very few core dependencies, predominantly PyTorch and numpy. All other dependencies (gym, torchvision, wandb / tensorboard) are optional.
Data
TensorDict
import torch
from tensordict import TensorDict
Let's create a TensorDict. The constructor accepts many different formats, such as passing a dict or keyword arguments:
batch_size = 5
data = TensorDict(
key1=torch.zeros(batch_size, 3),
key2=torch.zeros(batch_size, 5, 6, dtype=torch.bool),
batch_size=[batch_size],
)
print(data)
You can index a TensorDict along its batch_size, as well as query keys.
print(data[2])
print(data["key1"] is data.get("key1"))
The following shows how to stack multiple TensorDicts. This is particularly useful when writing rollout loops!
data1 = TensorDict(
{
"key1": torch.zeros(batch_size, 1),
"key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
},
batch_size=[batch_size],
)
data2 = TensorDict(
{
"key1": torch.ones(batch_size, 1),
"key2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
},
batch_size=[batch_size],
)
data = torch.stack([data1, data2], 0)
data.batch_size, data["key1"]
Here are some other functionalities of TensorDict: viewing, permuting, sharing memory, or expanding.
print(
"view(-1): ",
data.view(-1).batch_size,
data.view(-1).get("key1").shape,
)
print("to device: ", data.to("cpu"))
# print("pin_memory: ", data.pin_memory())
print("share memory: ", data.share_memory_())
print(
"permute(1, 0): ",
data.permute(1, 0).batch_size,
data.permute(1, 0).get("key1").shape,
)
print(
"expand: ",
data.expand(3, *data.batch_size).batch_size,
data.expand(3, *data.batch_size).get("key1").shape,
)
You can also create **nested data**.
data = TensorDict(
source={
"key1": torch.zeros(batch_size, 3),
"key2": TensorDict(
source={"sub_key1": torch.zeros(batch_size, 2, 1)},
batch_size=[batch_size, 2],
),
},
batch_size=[batch_size],
)
data
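Nested values can be read back with tuple-style keys. A small usage sketch based on the tensordict created just above:

print(data["key2", "sub_key1"].shape)  # torch.Size([5, 2, 1])
print(data["key2"].batch_size)  # the nested TensorDict keeps its own batch_size, [5, 2]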
Replay Buffers
Replay buffers are a crucial component in many RL algorithms. TorchRL provides a range of replay buffer implementations. Most basic features will work with any data structure (list, tuple, dict), but to use the replay buffers to their full extent and with fast read/write access, the TensorDict API should be preferred.
from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer
rb = ReplayBuffer(collate_fn=lambda x: x)
Adding can be done with add() (n=1) or extend() (n>1).
rb.add(1)
rb.sample(1)
rb.extend([2, 3])
rb.sample(3)
Prioritized replay buffers can also be used:
# alpha controls how much prioritization is applied; beta controls the
# importance-sampling correction of the induced bias
rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
rb.update_priority(1, 0.5)
Here are some examples of using a replay buffer with stacked data (tensordicts). Using them makes it easy to abstract the replay buffer behavior for multiple use cases.
collate_fn = torch.stack
rb = ReplayBuffer(collate_fn=collate_fn)
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(len(rb))
print(rb.sample(10))
print(rb.sample(2).contiguous())
torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer
rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
data_sample = rb.sample(2).contiguous()
print(data_sample)
print(data_sample["index"])
data_sample["td_error"] = torch.rand(2)
rb.update_tensordict_priority(data_sample)
# inspect the priorities stored in the sampler's internal sum-tree
for i, val in enumerate(rb._sampler._sum_tree):
    print(i, val)
    if i == len(rb):
        break
Environments
TorchRL provides a range of environment wrappers and utilities.
Gym environments
try:
    import gymnasium as gym
except ModuleNotFoundError:
    import gym
from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend
gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
env = GymEnv("Pendulum-v1")
data = env.reset()
env.rand_step(data)
Changing environment configurations
env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()
env.close()
del env
from torchrl.envs import (
Compose,
NoopResetEnv,
ObservationNorm,
ToTensorImage,
TransformedEnv,
)
base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
Environment transforms
Transforms are similar to Gym wrappers, but with an API closer to that of torchvision's transforms. There is a wide range of transforms to choose from.
from torchrl.envs import (
Compose,
NoopResetEnv,
ObservationNorm,
StepCounter,
ToTensorImage,
TransformedEnv,
)
base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()
print("env: ", env)
print("last transform parent: ", env.transform[2].parent)
Vectorized environments
Vectorized / parallel environments can provide significant speed-ups.
from torchrl.envs import ParallelEnv
def make_env():
    # You can control whether to use gym or gymnasium for your env
    with set_gym_backend("gym"):
        return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
base_env = ParallelEnv(
4,
make_env,
mp_start_method="fork", # This will break on Windows machines! Remove and decorate with if __name__ == "__main__"
)
env = TransformedEnv(
base_env, Compose(StepCounter(), ToTensorImage())
) # applies transforms on batch of envs
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()
print(env.action_spec)
env.close()
del env
Modules
Multiple modules (utilities, models and wrappers) can be found in the library.
Models
Example of an MLP model:
from torch import nn
from torchrl.modules import ConvNet, MLP
from torchrl.modules.models.utils import SquashDims
net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)
print(net(torch.randn(10, 3)).shape)
Example of a CNN model:
cnn = ConvNet(
num_cells=[32, 64],
kernel_sizes=[8, 4],
strides=[2, 1],
aggregator_class=SquashDims,
)
print(cnn)
print(cnn(torch.randn(10, 3, 32, 32)).shape)  # the trailing dims are squashed into a feature vector by SquashDims
TensorDictModules
Some modules are specifically designed to work with tensordict inputs.
from tensordict.nn import TensorDictModule
data = TensorDict({"key1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
td_module = TensorDictModule(module, in_keys=["key1"], out_keys=["key2"])
td_module(data)
print(data)
Sequences of modules
Making sequences of modules is easy with TensorDictSequential.
from tensordict.nn import TensorDictSequential
backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(
backbone_module, in_keys=["observation"], out_keys=["hidden"]
)
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
value_module = MLP(out_features=1, num_cells=[4, 5])
# both in_keys are passed to the MLP, which concatenates them along the last dim
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])
sequence = TensorDictSequential(backbone, actor, value)
print(sequence)
print(sequence.in_keys, sequence.out_keys)
data = TensorDict(
{"observation": torch.randn(3, 5)},
[3],
)
backbone(data)
actor(data)
value(data)
data = TensorDict(
{"observation": torch.randn(3, 5)},
[3],
)
sequence(data)
print(data)
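A sequence can also be truncated to the subgraph that produces a given set of outputs. A small sketch, assuming select_subsequence is available in your tensordict version:

# keep only the modules needed to compute "hidden"
sub_seq = sequence.select_subsequence(out_keys=["hidden"])
print(sub_seq.in_keys, sub_seq.out_keys)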
Functional Programming (Ensembling / Meta-RL)
Functional calls have never been easier. Extract the parameters with from_module(), and replace them with to_module().
from tensordict import from_module
params = from_module(sequence)
print("extracted params", params)
Functional call using a tensordict:
with params.to_module(sequence):
    data = sequence(data)
VMAP
Fast execution of multiple copies of a similar architecture is key to training your models quickly. vmap() is tailored to do just that.
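As a minimal sketch, reusing the sequence module and the params extracted above and following the common tensordict ensembling pattern (the names stacked_params and exec_sequence, and the printed shape, are illustrative):

# stack 4 independent copies of the parameters: vmap then runs one
# forward pass per parameter set
stacked_params = params.expand(4).clone()

def exec_sequence(p, td):
    # temporarily swap the given parameter set into the module
    with p.to_module(sequence):
        return sequence(td)

data = TensorDict({"observation": torch.randn(3, 5)}, [3])
out = torch.vmap(exec_sequence, (0, None))(stacked_params, data)
print(out["value"].shape)  # one value per parameter copy: torch.Size([4, 3, 1])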
Specialized classes
TorchRL also provides some specialized modules that run checks on their output values.
torch.manual_seed(0)
from torchrl.data import Bounded
from torchrl.modules import SafeModule
spec = Bounded(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
)
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(data)["action"]
data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
module(data)["action"] # safe=True projects the result within the set
The Actor class comes with a predefined output key ("action").
from torchrl.modules import Actor
base_module = nn.Linear(5, 3)
actor = Actor(base_module, in_keys=["obs"])
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(data) # action is the default value
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
Thanks to the tensordict.nn API, working with probabilistic models is also made easy.
from torchrl.modules import NormalParamExtractor, TanhNormal
td = TensorDict({"input": torch.randn(3, 5)}, [3])
net = nn.Sequential(
nn.Linear(5, 4), NormalParamExtractor()
) # splits the output in loc and scale
module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
module,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
return_log_prob=False,
),
)
td_module(td)
print(td)
# returning the log-probability
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictSequential(
module,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
return_log_prob=True,
),
)
td_module(td)
print(td)
Controlling randomness and sampling strategies is achieved via the set_exploration_type context manager.
from torchrl.envs.utils import ExplorationType, set_exploration_type
td = TensorDict({"input": torch.randn(3, 5)}, [3])
torch.manual_seed(0)
with set_exploration_type(ExplorationType.RANDOM):
    td_module(td)
    print("random:", td["action"])
with set_exploration_type(ExplorationType.DETERMINISTIC):
    td_module(td)
    print("deterministic:", td["action"])
Using environments and modules
Let's see how environments and modules can be combined.
from torchrl.envs.utils import step_mdp
env = GymEnv("Pendulum-v1")
action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = SafeModule(
actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
)
torch.manual_seed(0)
env.set_seed(0)
max_steps = 100
data = env.reset()
data_stack = TensorDict(batch_size=[max_steps])
for i in range(max_steps):
    actor(data)
    data_stack[i] = env.step(data)
    if data["next", "done"].any():
        break
    data = step_mdp(data)  # roughly equivalent to obs = next_obs
tensordicts_prealloc = data_stack.clone()
print("total steps:", i)
print(data_stack)
# equivalent rollout, appending to a python list and stacking at the end
torch.manual_seed(0)
env.set_seed(0)
max_steps = 100
data = env.reset()
data_stack = []
for i in range(max_steps):
    actor(data)
    data_stack.append(env.step(data))
    if data["next", "done"].any():
        break
    data = step_mdp(data)  # roughly equivalent to obs = next_obs
tensordicts_stack = torch.stack(data_stack, 0)
print("total steps:", i)
print(tensordicts_stack)
(tensordicts_stack == tensordicts_prealloc).all()
# env.rollout performs the same data collection in a single call
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout
(tensordict_rollout == tensordicts_prealloc).all()
from tensordict.nn import TensorDictModule
Collectors
We also provide a set of data collectors, which automatically gather as many frames per batch as required. They work from single-node, single-worker settings to multi-node, multi-worker settings.
from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.libs.gym import GymEnv
EnvCreator makes sure that we can send a lambda function from one process to another. We use a SerialEnv for simplicity (single worker), but for larger jobs a ParallelEnv (multi-workers) would be better suited.
Note

Multi-processed environments and multi-processed collectors can be combined!
parallel_env = SerialEnv(
3,
EnvCreator(lambda: GymEnv("Pendulum-v1")),
)
create_env_fn = [parallel_env, parallel_env]
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])
Sync multi-processed data collector
devices = ["cpu", "cpu"]
collector = MultiSyncDataCollector(
create_env_fn=create_env_fn, # either a list of functions or a ParallelEnv
policy=actor,
total_frames=240,
max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early
frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector)
device=devices,
)
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
Async multi-processed data collector
This class allows you to collect data while the model is training. This is particularly useful in off-policy settings as it decouples inference from model training. Data is delivered on a first-come, first-served basis (workers will queue their results):
collector = MultiaSyncDataCollector(
create_env_fn=create_env_fn, # either a list of functions or a ParallelEnv
policy=actor,
total_frames=240,
max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early
frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector)
device=devices,
)
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
del create_env_fn
del parallel_env
Objectives
Objectives are the main entry points when coding up new algorithms.
from torchrl.objectives import DDPGLoss
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])
class ConcatModule(nn.Linear):
    def forward(self, obs, action):
        return super().forward(torch.cat([obs, action], -1))
value_module = ConcatModule(4, 1)
value = TensorDictModule(
value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
)
loss_fn = DDPGLoss(actor, value)
loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)
data = TensorDict(
{
"observation": torch.randn(10, 3),
"next": {
"observation": torch.randn(10, 3),
"reward": torch.randn(10, 1),
"done": torch.zeros(10, 1, dtype=torch.bool),
},
"action": torch.randn(10, 1),
},
batch_size=[10],
device="cpu",
)
loss_td = loss_fn(data)
print(loss_td)
print(data)
Installing the Library
The library is published on PyPI: `pip install torchrl`. See the README for more information.
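Once installed, a quick sanity check can be run from Python (a minimal sketch; the version string shown will vary):

import torchrl
print(torchrl.__version__)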
Contributing
We are actively looking for contributors and early users. If you are working in RL (or simply curious), try it out! Give us feedback: what will make TorchRL succeed is how well it addresses researchers' needs. To do that, we need their input! Since the library is nascent, it is a great time for you to shape it the way you want!
See the contributing guide for more information.