Reinforcement Learning (PPO) with TorchRL Tutorial#
Created On: Mar 15, 2023 | Last Updated: Mar 20, 2025 | Last Verified: Nov 05, 2024
This tutorial demonstrates how to use PyTorch and torchrl
to train a parametric policy network to solve the Inverted Pendulum task from the OpenAI-Gym/Farama-Gymnasium control library.

Inverted pendulum#
Key learnings:
How to create an environment in TorchRL, transform its outputs, and collect data from this environment;
How to make your classes talk to each other using TensorDict;
The basics of building your training loop with TorchRL;
How to compute the advantage signal for policy gradient methods;
How to create a stochastic policy using a probabilistic neural network;
How to create a dynamic replay buffer and sample from it without repetition.
We will cover six crucial components of TorchRL: environments, transforms, models (policy and value function), loss modules, data collectors, and replay buffers.
If you are running this in Google Colab, make sure you install the following dependencies:
!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm
Proximal Policy Optimization (PPO) is a policy-gradient algorithm in which a batch of data is collected and directly consumed to train the policy to maximize the expected return given some proximity constraints. You can think of it as a sophisticated version of REINFORCE, the foundational policy-optimization algorithm. For more information, see the Proximal Policy Optimization Algorithms paper.
PPO is usually regarded as a fast and efficient method for online, on-policy reinforcement learning. TorchRL provides a loss module that does all the work for you, so that you can rely on this implementation and focus on solving your problem rather than reinventing the wheel every time you want to train a policy.
For completeness, here is a brief overview of what the loss computes, even though this is taken care of by our ClipPPOLoss
module — the algorithm works as follows: 1. we sample a batch of data by playing the policy in the environment for a given number of steps; 2. then, we perform a given number of optimization steps with random sub-samples of this batch using a clipped version of the REINFORCE loss; 3. the clipping puts a pessimistic bound on our loss: lower return estimates will be favored compared to higher ones. The precise formula of the loss is:
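(the clipped surrogate objective, as given in the PPO paper referenced above, which the description below refers to:)

$$
L(s, a, \theta_k, \theta) = \min\left(
\frac{\pi_{\theta}(a \mid s)}{\pi_{\theta_k}(a \mid s)}\, A^{\pi_{\theta_k}}(s, a),\;
\operatorname{clip}\!\left(\frac{\pi_{\theta}(a \mid s)}{\pi_{\theta_k}(a \mid s)},\, 1-\epsilon,\, 1+\epsilon\right) A^{\pi_{\theta_k}}(s, a)
\right)
$$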
There are two components in that loss: in the first part of the min operator, we simply compute an importance-weighted version of the REINFORCE loss (that is, a REINFORCE loss corrected for the fact that the current policy configuration lags behind the one used for the data collection). The second part of that min operator is a similar loss where the ratios have been clipped when they exceed or fall below a given pair of thresholds.
This loss ensures that, whether the advantage is positive or negative, policy updates that would produce a significant shift from the previous configuration are discouraged.
This tutorial is structured as follows:
First, we will define a set of hyperparameters we will be using for training.
Next, we will focus on creating our environment, or simulator, using TorchRL's wrappers and transforms.
Next, we will design the policy network and the value model, which is indispensable to the loss function. These modules will be used to configure our loss module.
Next, we will create the replay buffer and data loader.
Finally, we will run our training loop and analyze the results.
Throughout this tutorial, we'll be using the tensordict library.
TensorDict is the lingua franca of TorchRL: it helps us abstract away what a module reads and writes, so that we care less about the specific data description and more about the algorithm itself.
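To make this concrete, here is a minimal, self-contained sketch (with made-up keys and values, purely illustrative) of how a TensorDict is written to and read from:

import torch
from tensordict import TensorDict

# a dict-like container of tensors sharing a common batch size
td = TensorDict({"observation": torch.randn(4, 11)}, batch_size=[4])
td["action"] = torch.randn(4, 1)   # modules write new entries by key
print(td["observation"].shape)     # and read existing entries by key -> torch.Size([4, 11])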
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm
Define Hyperparameters#
We set the hyperparameters for our algorithm. Depending on the resources available, one may choose to execute the policy on GPU or on another device. The `frame_skip` will control how many frames a single action is executed for; the rest of the arguments that count frames must be corrected for this value (since one environment step will actually return `frame_skip` frames).
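As an illustration of that correction (hypothetical numbers; this tutorial effectively uses a frame skip of 1, so no adjustment is applied to the values defined below):

# hypothetical sketch: with a frame skip of 2, frame-counting arguments are divided
# by it so that the overall frame budget stays consistent
example_frame_skip = 2
example_total_frames = 50_000 // example_frame_skip       # 25_000
example_frames_per_batch = 1000 // example_frame_skip     # 500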
is_fork = multiprocessing.get_start_method() == "fork"
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
num_cells = 256 # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0
Data collection parameters#
When collecting data, we can choose how big each batch will be by defining a `frames_per_batch` parameter. We will also define how many frames (that is, interactions with the simulator) we allow ourselves to use. In general, the goal of an RL algorithm is to learn to solve the task as fast as it can in terms of environment interactions: the lower the `total_frames`, the better.
frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000
PPO parameters#
At each data collection (or batch collection), we will run the optimization over a certain number of *epochs*, each time consuming the entire data we just acquired in a nested training loop. Here, the `sub_batch_size` is different from the `frames_per_batch` above: recall that we are working with a "batch of data" coming from our collector, whose size is defined by `frames_per_batch`, and that we will further split into smaller sub-batches during the inner training loop. The size of these sub-batches is controlled by `sub_batch_size`.
sub_batch_size = 64 # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10 # optimization steps per batch of data collected
clip_epsilon = (
0.2 # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4
Define an environment#
In RL, an *environment* is usually the way we refer to a simulator or a control system. Various libraries provide simulation environments for reinforcement learning, including Gymnasium (previously OpenAI Gym), the DeepMind Control Suite, and many others. As a general library, TorchRL's goal is to provide an interchangeable interface to a large panel of RL simulators, allowing you to easily swap one environment for another. For example, a wrapped gym environment can be created with a few characters:
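The environment-creation call is not shown above; a minimal sketch consistent with the specs printed later in this tutorial (assuming the InvertedDoublePendulum-v4 task id) would be:

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)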
Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.org.cn/introduction/migration_guide/ for additional information.
There are a few things to notice in this code: first, we created the environment by calling the `GymEnv` wrapper. If extra keyword arguments are passed, they will be transmitted to the `gym.make` method, hence covering the most common environment-construction commands. Alternatively, one could also directly create a gym environment using `gym.make(env_name, **kwargs)` and wrap it in a `GymWrapper` class.
There is also the `device` argument: for gym, this only controls the device where the input actions and observed states are stored, but the execution is always done on CPU. The reason for this is simply that gym does not support on-device execution, unless specified otherwise. For other libraries, we have control over the execution device and, as much as we can, we try to stay consistent in terms of storing and execution backends.
Transforms#
We will append some transforms to our environment to prepare the data for the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different approach, more similar to other PyTorch domain libraries, through the use of transforms. To add transforms to an environment, one simply wraps it in a `TransformedEnv` instance and appends the sequence of transforms to it. The transformed environment will inherit the device and metadata of the wrapped environment, and transform these depending on the sequence of transforms it contains.
Normalization#
The first transform to encode is a normalization transform. As a rule of thumb, it is preferable to have data that loosely matches a unit Gaussian distribution: to obtain this, we will run a certain number of random steps in the environment and compute the summary statistics of these observations.
We will append two other transforms: the `DoubleToFloat` transform will convert double-precision entries to single-precision numbers, ready to be read by the policy. The `StepCounter` transform will be used to count the steps before the environment is terminated. We will use this measure as a supplementary measure of performance.
As we will see later, many of TorchRL's classes rely on `TensorDict` to communicate. You can think of it as a Python dictionary with some extra tensor features. In practice, this means that many of the modules we will be working with need to be told which key to read (`in_keys`) and which key to write (`out_keys`) in the `tensordict` they will receive. Usually, if `out_keys` is omitted, it is assumed that the `in_keys` entries will be updated in place. For our transforms, the only entry we are interested in is `observation`, and our transform layers will be told to modify this entry and this entry only.
env = TransformedEnv(
base_env,
Compose(
# normalize observations
ObservationNorm(in_keys=["observation"]),
DoubleToFloat(),
StepCounter(),
),
)
As you may have noticed, we have created a normalization layer, but we did not set its normalization parameters. To do this, `ObservationNorm` can automatically gather the summary statistics of our environment:
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
The `ObservationNorm` transform has now been populated with a location and a scale that will be used to normalize the data.
Let us do a little sanity check for the shape of our summary statistics:
print("normalization constant shape:", env.transform[0].loc.shape)
normalization constant shape: torch.Size([11])
An environment is not only defined by its simulator and transforms, but also by a series of metadata describing what can be expected during its execution. For efficiency purposes, TorchRL is quite stringent when it comes to environment specs, but you can easily check that your environment specs are adequate. In our example, the `GymWrapper` and the `GymEnv` that inherits from it already take care of setting the proper specs for your environment, so you should not have to worry about this.
Nevertheless, let's have a concrete look at our transformed environment by examining its specs. There are three specs to look at: `observation_spec`, which defines what is to be expected when executing an action in the environment; `reward_spec`, which indicates the reward domain; and finally the `input_spec` (which contains the `action_spec`), which represents everything the environment requires to execute a single step.
print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)
observation_spec: Composite(
observation: UnboundedContinuous(
shape=torch.Size([11]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
step_count: BoundedDiscrete(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=cpu,
shape=torch.Size([]),
data_cls=None)
reward_spec: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
input_spec: Composite(
full_state_spec: Composite(
step_count: BoundedDiscrete(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=cpu,
shape=torch.Size([]),
data_cls=None),
full_action_spec: Composite(
action: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
device=cpu,
shape=torch.Size([]),
data_cls=None),
device=cpu,
shape=torch.Size([]),
data_cls=None)
action_spec (as defined by input_spec): BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
The `check_env_specs()` function runs a small rollout and compares its output against the environment specs. If no error is raised, we can be confident that the specs are properly defined:
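The call that produces the log line below is simply:

check_env_specs(env)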
2025-08-07 18:26:50,027 [torchrl][INFO] check_env_specs succeeded! [END]
For fun, let's see what a simple random rollout looks like. You can call `env.rollout(n_steps)` and get an overview of what the environment inputs and outputs look like. Actions will automatically be drawn from the action-spec domain, so you don't need to worry about designing a random sampler.
Typically, at each step, an RL environment receives an action as input and outputs an observation, a reward and a done state. The observation may be composite, meaning that it could be composed of more than one tensor. This is not a problem for TorchRL, since the whole set of observations is automatically packed into the output `TensorDict`. After executing a rollout (for example, a sequence of environment steps and random action generations) over a given number of steps, we will retrieve a `TensorDict` instance with a shape that matches this trajectory length:
rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)
rollout of three steps: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]),
device=cpu,
is_shared=False)
Shape of the rollout TensorDict: torch.Size([3])
Our rollout data has a shape of `torch.Size([3])`, which matches the number of steps we ran it for. The `next` entry points to the data coming after the current step. In most cases, the `next` data at time `t` matches the data at `t+1`, but this may not be the case if we are using some specific transformations (for example, multi-step).
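A quick sanity check of that statement on the rollout above (a sketch; the equality holds here because no multi-step transform is used):

# the "next" observation at step t should match the root observation at step t+1
print(torch.allclose(rollout["observation"][1:], rollout["next", "observation"][:-1]))  # expected: True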
Policy#
PPO utilizes a stochastic policy to handle exploration. This means that our neural network will have to output the parameters of a distribution, rather than a single value corresponding to the action taken.
As the data is continuous, we use a Tanh-Normal distribution to respect the action-space boundaries. TorchRL provides such a distribution, and the only thing we need to care about is building a neural network that outputs the right number of parameters for the policy to work with (a location, or mean, and a scale).
The only extra difficulty that is brought up here is to split our output in two equal parts and map the second to a strictly positive space.
We design the policy in three steps:
1. Define a neural network `D_obs` -> `2 * D_action`. Indeed, our `loc` (mu) and `scale` (sigma) both have dimension `D_action`.
2. Append a `NormalParamExtractor` to extract a location and a scale (for example, split the input in two equal parts and apply a positive transformation to the scale parameter; a minimal sketch of this split follows the list).
3. Create a probabilistic `TensorDictModule` that can generate this distribution and sample from it.
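A small sketch of what the `NormalParamExtractor` in step 2 does, under the assumption of its default behavior (chunk the last dimension in two and map the second half to a strictly positive scale):

# illustrative only: split a (batch, 2 * D_action) tensor into loc and scale
split_example = NormalParamExtractor()
loc_example, scale_example = split_example(torch.randn(4, 2))
print(loc_example.shape, scale_example.shape, bool((scale_example > 0).all()))
# expected: torch.Size([4, 1]) torch.Size([4, 1]) True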
actor_net = nn.Sequential(
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
NormalParamExtractor(),
)
For the policy to "talk" with the environment through the `tensordict` data carrier, we wrap the `nn.Module` in a `TensorDictModule`. This class will simply read the `in_keys` it is provided with and write the outputs in place at the registered `out_keys`.
policy_module = TensorDictModule(
actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
We now need to build a distribution out of the location and scale of our normal distribution. To do so, we instruct the `ProbabilisticActor` class to build a `TanhNormal` and provide it with the minimum and maximum values of this distribution, which we gather from the environment specs.
The names of the `in_keys` (and hence the names of the `out_keys` of the `TensorDictModule` above) cannot be set to arbitrary values, as the `TanhNormal` distribution constructor expects the `loc` and `scale` keyword arguments. That being said, `ProbabilisticActor` also accepts `Dict[str, str]`-typed `in_keys`, where the key-value pairs indicate which `in_key` string should be used for each keyword argument.
policy_module = ProbabilisticActor(
module=policy_module,
spec=env.action_spec,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.action_spec.space.low,
"high": env.action_spec.space.high,
},
return_log_prob=True,
# we'll need the log-prob for the numerator of the importance weights
)
Value network#
The value network is a crucial component of the PPO algorithm, even though it won't be used at inference time. This module will read the observations and return an estimate of the discounted return for the following trajectory. This allows us to amortize learning by relying on a utility estimate that is learned on the fly during training. Our value network shares the same structure as the policy, but for simplicity we assign it its own set of parameters.
value_net = nn.Sequential(
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(1, device=device),
)
value_module = ValueOperator(
module=value_net,
in_keys=["observation"],
)
Let's try our policy and value modules. As we said earlier, the usage of `TensorDictModule` makes it possible to directly read the output of the environment to run these modules, as they know what information to read and where to write it:
print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))
Running policy: TensorDict(
fields={
action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
action_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False)
Running value: TensorDict(
fields={
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False)
Data collector#
TorchRL provides a set of `DataCollector` classes. Briefly, these classes execute three operations: reset an environment, compute an action given the latest observation, execute a step in the environment, and repeat the last two steps until the environment signals a stop (or reaches a done state).
They allow you to control how many frames to collect at each iteration (via the `frames_per_batch` parameter), when to reset the environment (via the `max_frames_per_traj` argument), on which `device` the policy should be executed, and so on. They are also designed to work efficiently with batched and multiprocessed environments.
The simplest data collector is the `SyncDataCollector`: it is an iterator that you can use to get batches of data until the total number of frames (`total_frames`) has been collected. Other data collectors (`MultiSyncDataCollector` and `MultiaSyncDataCollector`) will execute the same operations in synchronous and asynchronous manner over a set of multiprocessed workers.
As with the policy and environment before, the data collector will return `TensorDict` instances with a total number of elements that matches `frames_per_batch`. Using `TensorDict` to pass data to the training loop allows you to write data-loading pipelines that are 100% oblivious to the actual specificities of the rollout content.
collector = SyncDataCollector(
env,
policy_module,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
split_trajs=False,
device=device,
)
Replay buffer#
Replay buffers are a common building piece of off-policy RL algorithms. In on-policy contexts, a replay buffer is refilled every time a batch of data is collected, and its data is repeatedly consumed for a certain number of epochs.
TorchRL's replay buffers are built using a common container, `ReplayBuffer`, which takes the components of the buffer as arguments: a storage, a writer, a sampler and possibly some transforms. Only the storage (which indicates the replay buffer capacity) is mandatory. We also specify a sampler without repetition to avoid sampling the same item multiple times in one epoch. Using a replay buffer for PPO is not mandatory and we could simply sample the sub-batches from the collected batch, but using these classes makes it easy for us to build the inner training loop in a reproducible way.
replay_buffer = ReplayBuffer(
storage=LazyTensorStorage(max_size=frames_per_batch),
sampler=SamplerWithoutReplacement(),
)
Loss function#
The PPO loss can be directly imported from TorchRL for convenience using the `ClipPPOLoss` class. This is the easiest way of utilizing PPO: it hides away the mathematical operations of PPO and the control flow that goes with it.
PPO requires some "advantage estimation" to be computed. In short, an advantage is a value that reflects an expectancy over the return value while dealing with the bias/variance tradeoff. To compute the advantage, one just needs to (1) build the advantage module, which utilizes our value operator, and (2) pass each batch of data through it before each epoch. The GAE module will update the input `tensordict` with new `"advantage"` and `"value_target"` entries. The `"value_target"` is a gradient-free tensor that represents the empirical value the value network should output given the input observation. Both of these will be used by `ClipPPOLoss` to return the policy and value losses.
advantage_module = GAE(
gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True, device=device,
)
loss_module = ClipPPOLoss(
actor_network=policy_module,
critic_network=value_module,
clip_epsilon=clip_epsilon,
entropy_bonus=bool(entropy_eps),
entropy_coef=entropy_eps,
# these keys match by default but we set this for completeness
critic_coef=1.0,
loss_critic_type="smooth_l1",
)
optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optim, total_frames // frames_per_batch, 0.0
)
/usr/local/lib/python3.10/dist-packages/torchrl/objectives/ppo.py:384: DeprecationWarning:
'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.
/usr/local/lib/python3.10/dist-packages/torchrl/objectives/ppo.py:450: DeprecationWarning:
'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.
Training loop#
We now have all the pieces needed to code our training loop. The steps include:
Collect data
Compute advantage
Loop over the collected data to compute loss values
Back propagate
Optimize
Repeat
Repeat
Repeat
logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""
# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
# we now have a batch of data to work with. Let's learn something from it.
for _ in range(num_epochs):
# We'll need an "advantage" signal to make PPO work.
# We re-compute it at each epoch as its value depends on the value
# network which is updated in the inner loop.
advantage_module(tensordict_data)
data_view = tensordict_data.reshape(-1)
replay_buffer.extend(data_view.cpu())
for _ in range(frames_per_batch // sub_batch_size):
subdata = replay_buffer.sample(sub_batch_size)
loss_vals = loss_module(subdata.to(device))
loss_value = (
loss_vals["loss_objective"]
+ loss_vals["loss_critic"]
+ loss_vals["loss_entropy"]
)
# Optimization: backward, grad clipping and optimization step
loss_value.backward()
# this is not strictly mandatory but it's good practice to keep
# your gradient norm bounded
torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
optim.step()
optim.zero_grad()
logs["reward"].append(tensordict_data["next", "reward"].mean().item())
pbar.update(tensordict_data.numel())
cum_reward_str = (
f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
)
logs["step_count"].append(tensordict_data["step_count"].max().item())
stepcount_str = f"step count (max): {logs['step_count'][-1]}"
logs["lr"].append(optim.param_groups[0]["lr"])
lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
if i % 10 == 0:
# We evaluate the policy once every 10 batches of data.
# Evaluation is rather simple: execute the policy without exploration
# (take the expected value of the action distribution) for a given
# number of steps (1000, which is our ``env`` horizon).
# The ``rollout`` method of the ``env`` can take a policy as argument:
# it will then execute this policy at each step.
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
logs["eval reward (sum)"].append(
eval_rollout["next", "reward"].sum().item()
)
logs["eval step_count"].append(eval_rollout["step_count"].max().item())
eval_str = (
f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
f"eval step-count: {logs['eval step_count'][-1]}"
)
del eval_rollout
pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))
# We're also using a learning rate scheduler. Like the gradient clipping,
# this is a nice-to-have but nothing necessary for PPO to work.
scheduler.step()
0%| | 0/50000 [00:00<?, ?it/s]
2%|▏ | 1000/50000 [00:03<02:42, 301.18it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.0940 (init= 9.0940), step count (max): 15, lr policy: 0.0003: 2%|▏ | 1000/50000 [00:03<02:42, 301.18it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.0940 (init= 9.0940), step count (max): 15, lr policy: 0.0003: 4%|▍ | 2000/50000 [00:06<02:30, 319.34it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.1355 (init= 9.0940), step count (max): 17, lr policy: 0.0003: 4%|▍ | 2000/50000 [00:06<02:30, 319.34it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.1355 (init= 9.0940), step count (max): 17, lr policy: 0.0003: 6%|▌ | 3000/50000 [00:09<02:23, 326.85it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.1644 (init= 9.0940), step count (max): 22, lr policy: 0.0003: 6%|▌ | 3000/50000 [00:09<02:23, 326.85it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.1644 (init= 9.0940), step count (max): 22, lr policy: 0.0003: 8%|▊ | 4000/50000 [00:12<02:19, 330.93it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.1695 (init= 9.0940), step count (max): 17, lr policy: 0.0003: 8%|▊ | 4000/50000 [00:12<02:19, 330.93it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.1695 (init= 9.0940), step count (max): 17, lr policy: 0.0003: 10%|█ | 5000/50000 [00:15<02:14, 334.36it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2014 (init= 9.0940), step count (max): 25, lr policy: 0.0003: 10%|█ | 5000/50000 [00:15<02:14, 334.36it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2014 (init= 9.0940), step count (max): 25, lr policy: 0.0003: 12%|█▏ | 6000/50000 [00:18<02:10, 336.65it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2151 (init= 9.0940), step count (max): 25, lr policy: 0.0003: 12%|█▏ | 6000/50000 [00:18<02:10, 336.65it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2151 (init= 9.0940), step count (max): 25, lr policy: 0.0003: 14%|█▍ | 7000/50000 [00:21<02:07, 338.36it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2083 (init= 9.0940), step count (max): 36, lr policy: 0.0003: 14%|█▍ | 7000/50000 [00:21<02:07, 338.36it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2083 (init= 9.0940), step count (max): 36, lr policy: 0.0003: 16%|█▌ | 8000/50000 [00:24<02:05, 334.93it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2328 (init= 9.0940), step count (max): 33, lr policy: 0.0003: 16%|█▌ | 8000/50000 [00:24<02:05, 334.93it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2328 (init= 9.0940), step count (max): 33, lr policy: 0.0003: 18%|█▊ | 9000/50000 [00:26<02:01, 338.09it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2450 (init= 9.0940), step count (max): 42, lr policy: 0.0003: 18%|█▊ | 9000/50000 [00:26<02:01, 338.09it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2450 (init= 9.0940), step count (max): 42, lr policy: 0.0003: 20%|██ | 10000/50000 [00:29<01:58, 337.89it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2488 (init= 9.0940), step count (max): 52, lr policy: 0.0003: 20%|██ | 10000/50000 [00:29<01:58, 337.89it/s]
eval cumulative reward: 138.6874 (init: 138.6874), eval step-count: 14, average reward= 9.2488 (init= 9.0940), step count (max): 52, lr policy: 0.0003: 22%|██▏ | 11000/50000 [00:32<01:55, 338.58it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2468 (init= 9.0940), step count (max): 46, lr policy: 0.0003: 22%|██▏ | 11000/50000 [00:32<01:55, 338.58it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2468 (init= 9.0940), step count (max): 46, lr policy: 0.0003: 24%|██▍ | 12000/50000 [00:35<01:52, 338.24it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2722 (init= 9.0940), step count (max): 52, lr policy: 0.0003: 24%|██▍ | 12000/50000 [00:35<01:52, 338.24it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2722 (init= 9.0940), step count (max): 52, lr policy: 0.0003: 26%|██▌ | 13000/50000 [00:38<01:48, 340.52it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2725 (init= 9.0940), step count (max): 92, lr policy: 0.0003: 26%|██▌ | 13000/50000 [00:38<01:48, 340.52it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2725 (init= 9.0940), step count (max): 92, lr policy: 0.0003: 28%|██▊ | 14000/50000 [00:41<01:45, 342.36it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2782 (init= 9.0940), step count (max): 74, lr policy: 0.0003: 28%|██▊ | 14000/50000 [00:41<01:45, 342.36it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2782 (init= 9.0940), step count (max): 74, lr policy: 0.0003: 30%|███ | 15000/50000 [00:44<01:43, 337.91it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2688 (init= 9.0940), step count (max): 60, lr policy: 0.0002: 30%|███ | 15000/50000 [00:44<01:43, 337.91it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2688 (init= 9.0940), step count (max): 60, lr policy: 0.0002: 32%|███▏ | 16000/50000 [00:47<01:39, 340.42it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2747 (init= 9.0940), step count (max): 64, lr policy: 0.0002: 32%|███▏ | 16000/50000 [00:47<01:39, 340.42it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2747 (init= 9.0940), step count (max): 64, lr policy: 0.0002: 34%|███▍ | 17000/50000 [00:50<01:36, 342.17it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2801 (init= 9.0940), step count (max): 63, lr policy: 0.0002: 34%|███▍ | 17000/50000 [00:50<01:36, 342.17it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2801 (init= 9.0940), step count (max): 63, lr policy: 0.0002: 36%|███▌ | 18000/50000 [00:53<01:33, 343.63it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2818 (init= 9.0940), step count (max): 64, lr policy: 0.0002: 36%|███▌ | 18000/50000 [00:53<01:33, 343.63it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2818 (init= 9.0940), step count (max): 64, lr policy: 0.0002: 38%|███▊ | 19000/50000 [00:56<01:29, 344.60it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2908 (init= 9.0940), step count (max): 104, lr policy: 0.0002: 38%|███▊ | 19000/50000 [00:56<01:29, 344.60it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2908 (init= 9.0940), step count (max): 104, lr policy: 0.0002: 40%|████ | 20000/50000 [00:59<01:26, 345.25it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2901 (init= 9.0940), step count (max): 60, lr policy: 0.0002: 40%|████ | 20000/50000 [00:59<01:26, 345.25it/s]
eval cumulative reward: 260.0608 (init: 138.6874), eval step-count: 27, average reward= 9.2901 (init= 9.0940), step count (max): 60, lr policy: 0.0002: 42%|████▏ | 21000/50000 [01:01<01:23, 345.50it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2838 (init= 9.0940), step count (max): 56, lr policy: 0.0002: 42%|████▏ | 21000/50000 [01:02<01:23, 345.50it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2838 (init= 9.0940), step count (max): 56, lr policy: 0.0002: 44%|████▍ | 22000/50000 [01:05<01:22, 338.57it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2823 (init= 9.0940), step count (max): 50, lr policy: 0.0002: 44%|████▍ | 22000/50000 [01:05<01:22, 338.57it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2823 (init= 9.0940), step count (max): 50, lr policy: 0.0002: 46%|████▌ | 23000/50000 [01:07<01:19, 341.05it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2890 (init= 9.0940), step count (max): 63, lr policy: 0.0002: 46%|████▌ | 23000/50000 [01:07<01:19, 341.05it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2890 (init= 9.0940), step count (max): 63, lr policy: 0.0002: 48%|████▊ | 24000/50000 [01:10<01:15, 342.77it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2722 (init= 9.0940), step count (max): 51, lr policy: 0.0002: 48%|████▊ | 24000/50000 [01:10<01:15, 342.77it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2722 (init= 9.0940), step count (max): 51, lr policy: 0.0002: 50%|█████ | 25000/50000 [01:13<01:12, 343.10it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2801 (init= 9.0940), step count (max): 55, lr policy: 0.0002: 50%|█████ | 25000/50000 [01:13<01:12, 343.10it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2801 (init= 9.0940), step count (max): 55, lr policy: 0.0002: 52%|█████▏ | 26000/50000 [01:16<01:09, 343.39it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2827 (init= 9.0940), step count (max): 84, lr policy: 0.0001: 52%|█████▏ | 26000/50000 [01:16<01:09, 343.39it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2827 (init= 9.0940), step count (max): 84, lr policy: 0.0001: 54%|█████▍ | 27000/50000 [01:19<01:07, 342.50it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2908 (init= 9.0940), step count (max): 58, lr policy: 0.0001: 54%|█████▍ | 27000/50000 [01:19<01:07, 342.50it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2908 (init= 9.0940), step count (max): 58, lr policy: 0.0001: 56%|█████▌ | 28000/50000 [01:22<01:04, 341.62it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2935 (init= 9.0940), step count (max): 76, lr policy: 0.0001: 56%|█████▌ | 28000/50000 [01:22<01:04, 341.62it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2935 (init= 9.0940), step count (max): 76, lr policy: 0.0001: 58%|█████▊ | 29000/50000 [01:25<01:02, 338.01it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2957 (init= 9.0940), step count (max): 83, lr policy: 0.0001: 58%|█████▊ | 29000/50000 [01:25<01:02, 338.01it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.2957 (init= 9.0940), step count (max): 83, lr policy: 0.0001: 60%|██████ | 30000/50000 [01:28<00:58, 339.67it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.3056 (init= 9.0940), step count (max): 96, lr policy: 0.0001: 60%|██████ | 30000/50000 [01:28<00:58, 339.67it/s]
eval cumulative reward: 371.7483 (init: 138.6874), eval step-count: 39, average reward= 9.3056 (init= 9.0940), step count (max): 96, lr policy: 0.0001: 62%|██████▏ | 31000/50000 [01:31<00:55, 340.09it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3106 (init= 9.0940), step count (max): 89, lr policy: 0.0001: 62%|██████▏ | 31000/50000 [01:31<00:55, 340.09it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3106 (init= 9.0940), step count (max): 89, lr policy: 0.0001: 64%|██████▍ | 32000/50000 [01:34<00:53, 335.58it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3168 (init= 9.0940), step count (max): 121, lr policy: 0.0001: 64%|██████▍ | 32000/50000 [01:34<00:53, 335.58it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3168 (init= 9.0940), step count (max): 121, lr policy: 0.0001: 66%|██████▌ | 33000/50000 [01:37<00:50, 336.72it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3274 (init= 9.0940), step count (max): 141, lr policy: 0.0001: 66%|██████▌ | 33000/50000 [01:37<00:50, 336.72it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3274 (init= 9.0940), step count (max): 141, lr policy: 0.0001: 68%|██████▊ | 34000/50000 [01:40<00:47, 339.96it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3246 (init= 9.0940), step count (max): 128, lr policy: 0.0001: 68%|██████▊ | 34000/50000 [01:40<00:47, 339.96it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3246 (init= 9.0940), step count (max): 128, lr policy: 0.0001: 70%|███████ | 35000/50000 [01:43<00:44, 335.53it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3196 (init= 9.0940), step count (max): 135, lr policy: 0.0001: 70%|███████ | 35000/50000 [01:43<00:44, 335.53it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3196 (init= 9.0940), step count (max): 135, lr policy: 0.0001: 72%|███████▏ | 36000/50000 [01:46<00:41, 338.70it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3200 (init= 9.0940), step count (max): 107, lr policy: 0.0001: 72%|███████▏ | 36000/50000 [01:46<00:41, 338.70it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3200 (init= 9.0940), step count (max): 107, lr policy: 0.0001: 74%|███████▍ | 37000/50000 [01:49<00:38, 338.62it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3233 (init= 9.0940), step count (max): 105, lr policy: 0.0001: 74%|███████▍ | 37000/50000 [01:49<00:38, 338.62it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3233 (init= 9.0940), step count (max): 105, lr policy: 0.0001: 76%|███████▌ | 38000/50000 [01:52<00:35, 341.46it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3220 (init= 9.0940), step count (max): 131, lr policy: 0.0000: 76%|███████▌ | 38000/50000 [01:52<00:35, 341.46it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3220 (init= 9.0940), step count (max): 131, lr policy: 0.0000: 78%|███████▊ | 39000/50000 [01:54<00:32, 343.11it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3255 (init= 9.0940), step count (max): 164, lr policy: 0.0000: 78%|███████▊ | 39000/50000 [01:54<00:32, 343.11it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3255 (init= 9.0940), step count (max): 164, lr policy: 0.0000: 80%|████████ | 40000/50000 [01:57<00:29, 344.36it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3246 (init= 9.0940), step count (max): 133, lr policy: 0.0000: 80%|████████ | 40000/50000 [01:57<00:29, 344.36it/s]
eval cumulative reward: 718.0741 (init: 138.6874), eval step-count: 76, average reward= 9.3246 (init= 9.0940), step count (max): 133, lr policy: 0.0000: 82%|████████▏ | 41000/50000 [02:00<00:26, 338.47it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3249 (init= 9.0940), step count (max): 133, lr policy: 0.0000: 82%|████████▏ | 41000/50000 [02:01<00:26, 338.47it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3249 (init= 9.0940), step count (max): 133, lr policy: 0.0000: 84%|████████▍ | 42000/50000 [02:04<00:23, 333.55it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3279 (init= 9.0940), step count (max): 162, lr policy: 0.0000: 84%|████████▍ | 42000/50000 [02:04<00:23, 333.55it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3279 (init= 9.0940), step count (max): 162, lr policy: 0.0000: 86%|████████▌ | 43000/50000 [02:06<00:20, 335.05it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3246 (init= 9.0940), step count (max): 168, lr policy: 0.0000: 86%|████████▌ | 43000/50000 [02:06<00:20, 335.05it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3246 (init= 9.0940), step count (max): 168, lr policy: 0.0000: 88%|████████▊ | 44000/50000 [02:09<00:17, 336.91it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3277 (init= 9.0940), step count (max): 141, lr policy: 0.0000: 88%|████████▊ | 44000/50000 [02:09<00:17, 336.91it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3277 (init= 9.0940), step count (max): 141, lr policy: 0.0000: 90%|█████████ | 45000/50000 [02:12<00:14, 338.86it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3330 (init= 9.0940), step count (max): 153, lr policy: 0.0000: 90%|█████████ | 45000/50000 [02:12<00:14, 338.86it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3330 (init= 9.0940), step count (max): 153, lr policy: 0.0000: 92%|█████████▏| 46000/50000 [02:15<00:11, 340.52it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3319 (init= 9.0940), step count (max): 177, lr policy: 0.0000: 92%|█████████▏| 46000/50000 [02:15<00:11, 340.52it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3319 (init= 9.0940), step count (max): 177, lr policy: 0.0000: 94%|█████████▍| 47000/50000 [02:18<00:08, 340.38it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3320 (init= 9.0940), step count (max): 248, lr policy: 0.0000: 94%|█████████▍| 47000/50000 [02:18<00:08, 340.38it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3320 (init= 9.0940), step count (max): 248, lr policy: 0.0000: 96%|█████████▌| 48000/50000 [02:21<00:05, 337.91it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3339 (init= 9.0940), step count (max): 170, lr policy: 0.0000: 96%|█████████▌| 48000/50000 [02:21<00:05, 337.91it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3339 (init= 9.0940), step count (max): 170, lr policy: 0.0000: 98%|█████████▊| 49000/50000 [02:24<00:02, 340.73it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3344 (init= 9.0940), step count (max): 202, lr policy: 0.0000: 98%|█████████▊| 49000/50000 [02:24<00:02, 340.73it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3344 (init= 9.0940), step count (max): 202, lr policy: 0.0000: 100%|██████████| 50000/50000 [02:27<00:00, 343.08it/s]
eval cumulative reward: 886.8833 (init: 138.6874), eval step-count: 94, average reward= 9.3299 (init= 9.0940), step count (max): 155, lr policy: 0.0000: 100%|██████████| 50000/50000 [02:27<00:00, 343.08it/s]
Results#
Before the 1M step cap is reached, the algorithm should have reached a max step count of 1000 steps, which is the maximum number of steps before the trajectory is truncated.
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()

Conclusion and next steps#
In this tutorial, we have learned:
How to create and customize an environment with `torchrl`;
How to write a model and a loss function;
How to set up a typical training loop.
If you want to experiment with this tutorial a bit more, you can apply the following modifications:
From an efficiency perspective, we could run several simulations in parallel to speed up data collection. Check `ParallelEnv` for further information (a minimal sketch follows this list).
From a logging perspective, one could add a `torchrl.record.VideoRecorder` transform to the environment after asking for rendering to get a visual rendering of the inverted pendulum in action. Check `torchrl.record` to know more.
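A minimal `ParallelEnv` sketch (assumptions: the same InvertedDoublePendulum task as above, and the transforms shown earlier would then be applied on top of the batched environment):

from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

def make_env():
    # each worker builds its own copy of the base environment
    return GymEnv("InvertedDoublePendulum-v4")

parallel_env = ParallelEnv(4, make_env)  # four environment copies stepped in parallel processes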
Total running time of the script: (2 minutes 29.300 seconds)