Reinforcement Learning (PPO) with TorchRL Tutorial#
Created On: Mar 15, 2023 | Last Updated: Sep 17, 2025 | Last Verified: Nov 05, 2024
This tutorial demonstrates how to use PyTorch and torchrl to train a parametric policy network to solve the Inverted Pendulum task from the OpenAI-Gym/Farama-Gymnasium control library.
[Figure: Inverted pendulum]
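Key learnings:
How to create an environment in TorchRL, transform its outputs, and collect data from this environment;
How to make your classes talk to each other using TensorDict;
The basics of building your training loop with TorchRL:
  How to compute the advantage signal for policy gradient methods;
  How to create a stochastic policy using a probabilistic neural network;
  How to create a dynamic replay buffer and sample from it without replacement.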
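We will cover six crucial components of TorchRL: environments, transforms, models (policy and value function), loss modules, data collectors, and replay buffers.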
If you are running this in Google Colab, make sure you install the following dependencies:
!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm
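Proximal Policy Optimization (PPO) is a policy-gradient algorithm. It collects a batch of data and consumes it directly to train the policy, maximizing the expected return subject to a proximality constraint. You can think of it as a sophisticated version of REINFORCE, the foundational policy-optimization algorithm. For more information, see the Proximal Policy Optimization Algorithms paper.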
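PPO is usually regarded as a fast and efficient method for online, on-policy reinforcement learning. TorchRL provides a loss module that does all the work for you. You can rely on this implementation and focus on solving your problem, rather than reinventing the wheel every time you want to train a policy.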
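For completeness, here is a brief overview of what the loss computes, even though our ClipPPOLoss module takes care of it. The algorithm works as follows:
1. We sample a batch of data by playing the policy in the environment for a given number of steps.
2. Then, we perform a given number of optimization steps with random sub-samples of this batch, using a clipped version of the REINFORCE loss.
3. The clipping puts a pessimistic bound on our loss: lower return estimates are favored over higher ones.
The precise formula of the loss (an objective to be maximized; in practice its negative is minimized) is:
$$L(s, a, \theta_k, \theta) = \min\left( \frac{\pi_{\theta}(a \mid s)}{\pi_{\theta_k}(a \mid s)} A^{\pi_{\theta_k}}(s, a),\; \operatorname{clip}\left( \frac{\pi_{\theta}(a \mid s)}{\pi_{\theta_k}(a \mid s)},\, 1 - \epsilon,\, 1 + \epsilon \right) A^{\pi_{\theta_k}}(s, a) \right)$$
Here $\theta_k$ denotes the parameters of the policy that collected the data, $A^{\pi_{\theta_k}}$ is the advantage, and $\epsilon$ is the clipping threshold (clip_epsilon in the code below).
The loss has two components. The first term of the minimum operator is simply an importance-weighted version of the REINFORCE loss, i.e., a REINFORCE loss corrected for the fact that the current policy configuration lags behind the one used for data collection. The second term is a similar loss in which the ratio is clipped whenever it goes above or below a given pair of thresholds.
Whether the advantage is positive or negative, this loss discourages policy updates that would shift the policy significantly away from its previous configuration.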
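To make the two terms of the minimum concrete, here is a minimal, framework-free sketch of the per-sample clipped objective. It is not ClipPPOLoss, which works on batched tensors and also adds the critic and entropy terms. The function names `clipped_surrogate` and `ppo_policy_loss` are hypothetical, and log-probabilities are passed in as plain floats.

```python
import math


def clipped_surrogate(logp_new, logp_old, advantage, clip_epsilon=0.2):
    """Per-sample clipped PPO objective (a quantity to maximize)."""
    # Importance ratio pi_theta(a|s) / pi_theta_k(a|s), recovered from log-probs.
    ratio = math.exp(logp_new - logp_old)
    # First term of the min: importance-weighted REINFORCE objective.
    unclipped = ratio * advantage
    # Second term: same objective with the ratio clipped to [1 - eps, 1 + eps].
    clipped_ratio = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
    clipped = clipped_ratio * advantage
    # Keep the pessimistic (smaller) of the two.
    return min(unclipped, clipped)


def ppo_policy_loss(samples, clip_epsilon=0.2):
    """Negative mean surrogate over (logp_new, logp_old, advantage) triples, to minimize."""
    terms = [clipped_surrogate(n, o, a, clip_epsilon) for n, o, a in samples]
    return -sum(terms) / len(terms)


# Ratio 1.5 with a positive advantage: the gain is capped at 1 + eps.
print(clipped_surrogate(math.log(1.5), 0.0, 1.0))
# Ratio 0.5 with a negative advantage: the clipped term is the smaller one.
print(clipped_surrogate(math.log(0.5), 0.0, -1.0))
```

In both cases, the update gets no extra credit for moving the ratio outside [1 - eps, 1 + eps]. This is the pessimistic bound described above.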
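The tutorial is structured as follows:
First, we will define a set of hyperparameters to be used for training.
Next, we will focus on creating our environment, or simulator, using TorchRL's wrappers and transforms.
Next, we will design the policy network and the value model, which is indispensable to the loss function. These modules will be used to configure our loss module.
Next, we will create the replay buffer and data loader.
Finally, we will run our training loop and analyze the results.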
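Throughout this tutorial, we'll be using the tensordict library. TensorDict is the lingua franca of TorchRL. It abstracts away what a module reads and writes, so we can care less about the specific data description and more about the algorithm itself.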
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm
Define Hyperparameters#
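We set the hyperparameters for our algorithm. Depending on the resources available, one may choose to execute the policy on GPU or on another device. frame_skip controls for how many frames a single action is executed. The other arguments that count frames must be corrected for this value, since one environment step actually returns frame_skip frames.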
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
num_cells = 256 # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0
Data collection parameters#
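When collecting data, we choose how big each batch will be by defining a frames_per_batch parameter. We also define how many frames (i.e., the number of interactions with the simulator) we allow ourselves to use. In general, the goal of an RL algorithm is to solve the task in as few environment interactions as possible: the lower the total_frames, the better.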
frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000
PPO parameters#
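After each data collection (or batch collection), we run the optimization for a certain number of epochs in a nested training loop, each time consuming all the data we just acquired. Note that sub_batch_size is different from frames_per_batch above. The collector produces a "batch of data" whose size is defined by frames_per_batch. The inner training loop then splits that batch further into smaller sub-batches, whose size is controlled by sub_batch_size.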
sub_batch_size = 64 # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10 # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4
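As a quick sanity check on how these knobs interact, the plain-Python arithmetic below reuses the values set above. It counts how many optimizer steps the nested loop described here will take. It assumes, as the training loop later does, that each epoch draws `frames_per_batch // sub_batch_size` sub-batches. Under that assumption, at most 15 × 64 = 960 of the 1,000 frames are drawn per epoch.

```python
# Values copied from the hyperparameter cells above.
frames_per_batch = 1000
total_frames = 50_000
sub_batch_size = 64
num_epochs = 10

num_batches = total_frames // frames_per_batch        # iterations of the collector
steps_per_epoch = frames_per_batch // sub_batch_size  # sub-batches sampled per epoch
frames_seen_per_epoch = steps_per_epoch * sub_batch_size
updates_per_batch = num_epochs * steps_per_epoch      # optimizer steps per collected batch
total_updates = num_batches * updates_per_batch

print(f"{num_batches} batches, {steps_per_epoch} sub-batches/epoch "
      f"({frames_seen_per_epoch} of {frames_per_batch} frames), "
      f"{updates_per_batch} updates/batch, {total_updates} updates in total")
```

With 1M frames, as the comment above suggests for a complete training, the same arithmetic gives 1,000 batches and 150,000 optimizer steps.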
Define an environment#
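In RL, "environment" usually refers to a simulator or a control system. Various libraries provide simulation environments for reinforcement learning, including Gymnasium (previously OpenAI Gym), the DeepMind Control Suite, and many others. As a general library, TorchRL aims to provide an interchangeable interface to a large panel of RL simulators, so you can easily swap one environment for another. For example, creating a wrapped Gym environment takes only a few characters: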
Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.org.cn/introduction/migration_guide/ for additional information.
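There are a few things to notice in this code. First, we created the environment by calling the GymEnv wrapper. Any extra keyword arguments are forwarded to the gym.make method, which covers the most common environment construction commands. Alternatively, one could directly create a gym environment using gym.make(env_name, **kwargs) and wrap it in a GymWrapper class.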
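Also note the device argument. For Gym, it only controls the device where input actions and observed states are stored; execution is always done on CPU. The reason is simply that Gym does not support on-device execution unless specified otherwise. For other libraries, we have control over the execution device and, as much as possible, we keep the storage and execution backends consistent.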
Transforms#
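We will append some transforms to our environment to prepare the data for the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different approach, closer to other PyTorch domain libraries, through the use of transforms. To add transforms to an environment, simply wrap it in a TransformedEnv instance and append the sequence of transforms to it. The transformed environment inherits the device and metadata of the wrapped environment, and transforms these according to the sequence of transforms it contains.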
Normalization#
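The first transform to encode is a normalization transform. As a rule of thumb, it is preferable to have data that loosely matches a unit Gaussian distribution. To obtain this, we will run a certain number of random steps in the environment and compute the summary statistics of these observations.
We'll append two other transforms. The DoubleToFloat transform converts double-precision entries to single-precision numbers, ready to be read by the policy. The StepCounter transform counts the steps before the environment is terminated; we will use this count as a supplementary measure of performance.
As we will see later, many of TorchRL's classes rely on TensorDict to communicate. You can think of it as a Python dictionary with some extra tensor features. In practice, many of the modules we work with must be told which keys to read (in_keys) and which keys to write (out_keys) in the tensordict they receive. Usually, if out_keys is omitted, the in_keys entries are assumed to be updated in place. For our transforms, the only entry we are interested in is called "observation", and our transform layers will be told to modify this entry and this entry only: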
env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)
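The in_keys/out_keys convention can be illustrated without TensorDict at all. The sketch below uses a plain Python dict as a stand-in for a tensordict, together with a hypothetical `DictModule` class (not a TorchRL API). The class reads only its in_keys, writes only its out_keys, and falls back to in-place updates when out_keys is omitted. The hypothetical `compose` helper chains modules in order, which loosely mirrors what Compose does for transforms.

```python
class DictModule:
    """Hypothetical stand-in for a key-aware module operating on a dict."""

    def __init__(self, fn, in_keys, out_keys=None):
        self.fn = fn
        self.in_keys = list(in_keys)
        # Without out_keys, results are written back to the in_keys (in place).
        self.out_keys = list(out_keys) if out_keys is not None else list(self.in_keys)

    def __call__(self, td):
        outputs = self.fn(*(td[key] for key in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        for key, value in zip(self.out_keys, outputs):
            td[key] = value
        return td


def compose(*modules):
    """Apply modules one after another, like a chain of transforms."""
    def run(td):
        for module in modules:
            td = module(td)
        return td
    return run


# Only "observation" is read; "reward" passes through untouched.
to_float = DictModule(lambda obs: [float(x) for x in obs], in_keys=["observation"])
double = DictModule(lambda obs: [2.0 * x for x in obs], in_keys=["observation"])
bounds = DictModule(
    lambda obs: (min(obs), max(obs)), in_keys=["observation"], out_keys=["low", "high"]
)

pipeline = compose(to_float, double, bounds)
td = pipeline({"observation": [1, 2, 3], "reward": 0.5})
print(td)
```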
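As you may have noticed, we created a normalization layer but did not set its normalization parameters. ObservationNorm can gather the summary statistics of our environment automatically: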
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
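The ObservationNorm transform has now been populated with a location and a scale that will be used to normalize the data.
Let us do a little sanity check on the shape of our summary stats: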
print("normalization constant shape:", env.transform[0].loc.shape)
normalization constant shape: torch.Size([11])
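For intuition about what these summary statistics are, here is a plain-Python sketch. It uses Gaussian noise as a hypothetical stand-in for observations gathered from random steps. It reduces over the sample dimension to get one location and one scale per feature, which is why the shape above is [11] for 11 observation features, and then standardizes. ObservationNorm's exact internal convention (for example, how loc and scale are stored and applied) is not reproduced here; this only illustrates the idea.

```python
import random
import statistics

rng = random.Random(0)
true_means, true_stds = [5.0, -2.0, 0.1], [3.0, 0.5, 10.0]

# 1000 hypothetical observation vectors with 3 features each.
samples = [[rng.gauss(m, s) for m, s in zip(true_means, true_stds)] for _ in range(1000)]

# Reduce over the sample dimension: one loc and one scale per feature.
columns = list(zip(*samples))
loc = [statistics.fmean(col) for col in columns]
scale = [statistics.pstdev(col) for col in columns]


def normalize(obs, eps=1e-6):
    # Standardize each feature so the data loosely matches a unit Gaussian.
    return [(x - m) / (s + eps) for x, m, s in zip(obs, loc, scale)]


normalized = [normalize(obs) for obs in samples]
print("loc:", [round(m, 2) for m in loc])
print("scale:", [round(s, 2) for s in scale])
```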
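An environment is defined not only by its simulator and transforms, but also by a series of metadata that describe what can be expected during its execution. For efficiency, TorchRL is quite stringent about environment specs, but you can easily check that your environment specs are adequate. In our example, GymEnv (which inherits from GymWrapper) already takes care of setting the proper specs for your environment, so you should not have to worry about this.
Nevertheless, let's look at the specs of our transformed environment as a concrete example. There are three specs to look at: observation_spec, which defines what is to be expected when executing an action in the environment; reward_spec, which indicates the reward domain; and input_spec (which contains the action_spec), which represents everything an environment requires to execute a single step.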
print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)
observation_spec: Composite(
observation: UnboundedContinuous(
shape=torch.Size([11]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
step_count: BoundedDiscrete(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=cpu,
shape=torch.Size([]),
data_cls=None)
reward_spec: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
input_spec: Composite(
full_state_spec: Composite(
step_count: BoundedDiscrete(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=cpu,
shape=torch.Size([]),
data_cls=None),
full_action_spec: Composite(
action: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
device=cpu,
shape=torch.Size([]),
data_cls=None),
device=cpu,
shape=torch.Size([]),
data_cls=None)
action_spec (as defined by input_spec): BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
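The check_env_specs() function runs a small rollout and compares its output against the environment specs. If no error is raised, we can be confident that the specs are properly defined:
check_env_specs(env)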
2026-06-03 00:35:03,116 [torchrl][INFO] check_env_specs succeeded! [END]
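For fun, let's see what a simple random rollout looks like. You can call env.rollout(n_steps) to get an overview of the environment's inputs and outputs. Actions are automatically drawn from the action spec domain, so you don't need to design a random sampler.
Typically, at each step, an RL environment receives an action as input and outputs an observation, a reward and a done state. The observation may be composite, meaning that it can consist of more than one tensor. This is not a problem for TorchRL, since the whole set of observations is automatically packed into the output TensorDict. After executing a rollout (i.e., a sequence of environment steps and random action generations) over a given number of steps, we retrieve a TensorDict instance whose shape matches this trajectory length: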
rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)
rollout of three steps: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]),
device=cpu,
is_shared=False)
Shape of the rollout TensorDict: torch.Size([3])
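Our rollout data has a shape of torch.Size([3]), which matches the number of steps we ran. The "next" entry points to the data coming after the current step. In most cases, the "next" data at time t matches the data at t+1, but this may not hold with some specific transformations (for example, multi-step).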
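The "next" convention can be mimicked with plain dicts. In the sketch below, a hypothetical toy environment `CounterEnv` (not a TorchRL class) is rolled out with random actions, and each step stores its post-step data under "next". The final loop checks that, with no reset in between, step t's "next" observation is step t+1's observation. Unlike the real rollout method, this sketch simply stops at the end of an episode.

```python
import random


class CounterEnv:
    """Hypothetical toy environment: the observation is the running sum of actions."""

    def reset(self):
        self.state = 0
        return {"observation": self.state}

    def step(self, action):
        self.state += action
        done = abs(self.state) >= 3
        return {"observation": self.state, "reward": -abs(self.state), "done": done}


def rollout(env, n_steps, seed=0):
    rng = random.Random(seed)
    current = env.reset()
    trajectory = []
    for _ in range(n_steps):
        action = rng.choice([-1, 1])  # random action drawn from the action domain
        step = dict(current, action=action)
        step["next"] = env.step(action)
        trajectory.append(step)
        if step["next"]["done"]:
            break  # this sketch stops at the end of an episode
        # The "next" data of step t becomes the root data of step t + 1.
        current = {"observation": step["next"]["observation"]}
    return trajectory


traj = rollout(CounterEnv(), 3)
for t in range(len(traj) - 1):
    assert traj[t]["next"]["observation"] == traj[t + 1]["observation"]
print(len(traj), [s["observation"] for s in traj])
```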
Policy#
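PPO uses a stochastic policy to handle exploration. This means our neural network must output the parameters of a distribution, rather than a single value corresponding to the action taken.
As the data is continuous, we use a Tanh-Normal distribution to respect the action space boundaries. TorchRL provides such a distribution, so the only thing we need to do is build a neural network that outputs the right number of parameters for the policy: a location (or mean) and a scale.
The only extra difficulty is splitting our output into two equal parts and mapping the second part to a strictly positive space.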
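We design the policy in three steps:
1. Define a neural network D_obs -> 2 * D_action. Indeed, our loc (mu) and scale (sigma) both have dimension D_action.
2. Append a NormalParamExtractor to extract a location and a scale (i.e., split the input into two equal parts and apply a positive transformation to the scale parameter).
3. Create a probabilistic TensorDictModule that can generate this distribution and sample from it.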
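The second step (splitting the network output and making the scale positive) can be sketched in plain Python as follows. The helper `split_normal_params` is hypothetical, and the softplus used here is an assumption made for illustration; NormalParamExtractor's own positive mapping may differ.

```python
import math


def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)), strictly positive for real x.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def split_normal_params(output):
    """Split a 2 * D_action vector into (loc, scale) with every scale > 0."""
    if len(output) % 2:
        raise ValueError("expected an even-length output (2 * D_action)")
    d = len(output) // 2
    loc = list(output[:d])                      # first half: location (mean)
    scale = [softplus(x) for x in output[d:]]   # second half: mapped to (0, inf)
    return loc, scale


# D_action = 1, as in this tutorial's action spec: a raw output of 2 numbers.
loc, scale = split_normal_params([0.3, -5.0])
print(loc, scale)
```

A negative raw value such as -5.0 still yields a small positive scale. This is the "strictly positive space" requirement mentioned above.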
actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)
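To let the policy "talk" with the environment through the tensordict data carrier, we wrap the nn.Module in a TensorDictModule. This class reads the in_keys it is provided with and writes its outputs in place at the registered out_keys.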
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
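We now need to build a distribution out of the location and scale of our normal distribution. To do so, we instruct the ProbabilisticActor class to build a TanhNormal from the location and scale parameters. We also provide the minimum and maximum values of this distribution, which we gather from the environment specs.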
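Why does a TanhNormal respect the bounds? A sample from the underlying normal distribution is squashed by tanh into [-1, 1] and then mapped affinely onto [low, high]. The sketch below shows one such parametrization in plain Python. The function name and the bounds are hypothetical, and TorchRL's TanhNormal may differ in details, for instance in how it transforms the location.

```python
import math
import random


def sample_tanh_normal(loc, scale, low, high, rng):
    """Draw one action from a tanh-squashed normal rescaled to [low, high]."""
    u = rng.gauss(loc, scale)   # unbounded sample from Normal(loc, scale)
    squashed = math.tanh(u)     # within [-1, 1] (open interval mathematically)
    return low + (high - low) * (squashed + 1.0) / 2.0  # affine map onto [low, high]


rng = random.Random(0)
low, high = -2.0, 3.0  # hypothetical action bounds, for illustration only
actions = [sample_tanh_normal(0.5, 2.0, low, high, rng) for _ in range(10_000)]
print(min(actions), max(actions))
```

However wide the scale is, no sample can leave the action box. This is why the bounds are passed to the distribution rather than enforced by clipping afterwards.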
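The names of the in_keys (and hence of the out_keys of the TensorDictModule above) cannot be chosen freely, because the TanhNormal distribution constructor expects the loc and scale keyword arguments. That being said, ProbabilisticActor also accepts in_keys of type Dict[str, str], where the key-value pairs indicate which in_key should be used for each keyword argument.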
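The dictionary form of in_keys can be pictured as a mapping from distribution keyword names to tensordict keys. The helper below is a hypothetical plain-Python illustration of that routing, using a dict in place of a tensordict. It is not how ProbabilisticActor is implemented.

```python
def gather_dist_kwargs(td, in_keys):
    """Build distribution keyword arguments from a dict-like tensordict."""
    if isinstance(in_keys, dict):
        # {distribution keyword: tensordict key}
        return {kwarg: td[key] for kwarg, key in in_keys.items()}
    # A plain list only works if the keys are already named like the keywords.
    return {key: td[key] for key in in_keys}


td = {"my_loc": 0.1, "my_scale": 0.9, "observation": [0.0]}
print(gather_dist_kwargs(td, {"loc": "my_loc", "scale": "my_scale"}))
print(gather_dist_kwargs({"loc": 0.1, "scale": 0.9}, ["loc", "scale"]))
```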
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec.space.low,
        "high": env.action_spec.space.high,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)
Value network#
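The value network is a crucial component of the PPO algorithm, even though it is not used at inference time. This module reads the observations and returns an estimate of the discounted return for the rest of the trajectory. This lets us amortize learning by relying on a utility estimate that is learned on the fly during training. Our value network has the same structure as the policy, but for simplicity we give it its own set of parameters.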
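The quantity the value network learns to approximate is a discounted return. Below is a minimal plain-Python sketch of its Monte-Carlo version over a trajectory segment; the helper name is hypothetical. It does not bootstrap from a value estimate at the end of an unfinished segment. Filling that gap is precisely the role of the learned estimate (and of GAE later on).

```python
def discounted_returns(rewards, dones, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, restarted after each terminal step."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0  # nothing is accumulated past the end of an episode
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


print(discounted_returns([1.0, 1.0, 1.0], [False, False, True], gamma=0.5))
# [1.75, 1.5, 1.0]
```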
value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)
value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)
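Let's try our policy and value modules. As mentioned earlier, TensorDictModule lets these modules read the output of the environment directly, because they know what information to read and where to write it: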
print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))
Running policy: TensorDict(
fields={
action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
action_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False)
Running value: TensorDict(
fields={
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False)
Data collector#
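TorchRL provides a set of DataCollector classes. Briefly, these classes reset an environment, compute an action given the latest observation, and execute a step in the environment. They then repeat the last two operations until the environment signals a stop (or reaches a done state).
They let you control how many frames to collect at each iteration (through the frames_per_batch parameter), when to reset the environment (through the max_frames_per_traj argument), on which device the policy should be executed, and so on. They are also designed to work efficiently with batched and multiprocessed environments.
The simplest data collector is the SyncDataCollector. It is an iterator that yields batches of data of a given length and stops once a total number of frames (total_frames) has been collected. The other data collectors, MultiSyncDataCollector and MultiaSyncDataCollector, perform the same operations over a set of multiprocessed workers, synchronously and asynchronously respectively.
Like the policy and environment before, the data collector returns TensorDict instances whose total number of elements matches frames_per_batch. Passing data to the training loop as TensorDicts lets you write data-loading pipelines that are completely agnostic to the specifics of the rollout content.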
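To make the collector's loop concrete, here is a hypothetical single-process, pure-Python analogue; it is not SyncDataCollector itself. It yields lists of exactly frames_per_batch transitions and resets the toy environment whenever a step reports done. It stops once total_frames transitions have been produced. `WalkEnv` and `collect` are illustrative names.

```python
class WalkEnv:
    """Hypothetical toy environment whose episodes last `horizon` steps."""

    def __init__(self, horizon=5):
        self.horizon = horizon

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        self.t += 1
        return float(self.t), 1.0, self.t >= self.horizon  # next obs, reward, done


def collect(env, policy, frames_per_batch, total_frames):
    """Yield batches of transitions until total_frames have been collected."""
    collected = 0
    obs = env.reset()
    while collected < total_frames:
        batch = []
        for _ in range(frames_per_batch):
            action = policy(obs)                       # compute an action
            next_obs, reward, done = env.step(action)  # step the environment
            batch.append({"observation": obs, "action": action,
                          "next": {"observation": next_obs, "reward": reward, "done": done}})
            obs = env.reset() if done else next_obs    # reset on done, else continue
        collected += len(batch)
        yield batch


batches = list(collect(WalkEnv(horizon=5), lambda obs: 0.0, frames_per_batch=4, total_frames=12))
print(len(batches), [len(b) for b in batches])
```

Because batches have a fixed number of frames rather than a fixed number of episodes, an episode can span two batches. In this example, the first episode ends inside the second batch.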
collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)
Replay buffer#
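Replay buffers are a common building block of off-policy RL algorithms. In on-policy contexts, a replay buffer is refilled every time a batch of data is collected, and its data is consumed repeatedly for a certain number of epochs.
TorchRL's replay buffers are built with a common container, ReplayBuffer, which takes the components of the buffer as arguments: a storage, a writer, a sampler and, optionally, some transforms. Only the storage (which sets the replay buffer capacity) is mandatory. We also specify a sampler without replacement to avoid sampling the same item more than once in an epoch. A replay buffer is not mandatory for PPO, since we could simply sample sub-batches from the collected batch. Still, these classes make it easy to build the inner training loop in a reproducible way.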
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)
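Sampling without replacement within an epoch amounts to shuffling the indices once and slicing them into disjoint sub-batches. The generator below is a plain-Python sketch of that idea, with a hypothetical name. It drops any leftover indices that do not fill a whole sub-batch, which is not necessarily what SamplerWithoutReplacement does by default.

```python
import random


def sample_without_replacement(num_items, batch_size, rng):
    """Yield disjoint index sub-batches from one shuffled epoch (remainder dropped)."""
    perm = list(range(num_items))
    rng.shuffle(perm)
    for start in range(0, num_items - batch_size + 1, batch_size):
        yield perm[start:start + batch_size]


rng = random.Random(0)
sub_batches = list(sample_without_replacement(1000, 64, rng))
seen = [idx for sb in sub_batches for idx in sb]
print(len(sub_batches), len(seen), len(set(seen)))
```

With frames_per_batch = 1000 and sub_batch_size = 64, this gives 15 sub-batches covering 960 distinct items. No item is drawn twice within the epoch.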
Loss function#
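For convenience, the PPO loss can be imported directly from TorchRL as the ClipPPOLoss class. This is the easiest way to use PPO: it hides the mathematical operations of PPO and the control flow that goes with them.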
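PPO requires an "advantage estimate" to be computed. In short, an advantage is a value that reflects an expectation over the return while managing the bias/variance tradeoff. To compute the advantage, one just needs to (1) build the advantage module, which uses our value operator, and (2) pass each batch of data through it before each epoch. The GAE module updates the input tensordict with new "advantage" and "value_target" entries. The "value_target" is a gradient-free tensor holding the empirical value that the value network should predict for the input observation. ClipPPOLoss uses both entries to return the policy and value losses.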
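As a reference for what the GAE module computes, here is a plain-Python sketch of generalized advantage estimation over one trajectory segment, under simplifying assumptions. A single done flag both stops bootstrapping and cuts the trajectory, whereas TorchRL distinguishes terminated from truncated. Value estimates are passed in as plain lists. `standardize` only mimics what, to our understanding, average_gae=True does, namely standardizing the advantages. All function names are hypothetical.

```python
def compute_gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation for one segment; returns (advantage, value_target)."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # Exponentially weighted (gamma * lmbda) sum of TD errors, reset at episode ends.
        gae = delta + gamma * lmbda * not_done * gae
        advantages[t] = gae
    # Regression target for the value network: advantage + current value estimate.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


def standardize(xs, eps=1e-8):
    """Zero-mean, unit-variance rescaling of the advantages."""
    mean = sum(xs) / len(xs)
    std = (sum((x - mean) ** 2 for x in xs) / len(xs)) ** 0.5
    return [(x - mean) / (std + eps) for x in xs]


# With zero value estimates and gamma = lmbda = 1, GAE reduces to the return-to-go.
adv, target = compute_gae(
    rewards=[1.0, 1.0, 1.0], values=[0.0, 0.0, 0.0], next_values=[0.0, 0.0, 0.0],
    dones=[False, False, True], gamma=1.0, lmbda=1.0,
)
print(adv, target, standardize(adv))
```

Setting lmbda closer to 0 leans on the value network (lower variance, more bias). Setting it closer to 1 leans on observed rewards; this is the bias/variance tradeoff mentioned above.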
advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True, device=device,
)
loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    # 'entropy_coeff' and 'critic_coeff' replace the deprecated 'entropy_coef' and
    # 'critic_coef' spellings; older TorchRL releases may only accept the latter
    entropy_coeff=entropy_eps,
    critic_coeff=1.0,
    loss_critic_type="smooth_l1",
)
optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)
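The scheduler configured above anneals the learning rate from lr down to 0.0 over total_frames // frames_per_batch = 50 scheduler steps, one per collected batch in the training loop. The helper below, whose name is hypothetical, evaluates the closed-form cosine-annealing expression given in the PyTorch documentation for CosineAnnealingLR. It is a standalone sketch for predicting the values the training log prints.

```python
import math


def cosine_annealing_lr(step, base_lr=3e-4, t_max=50, eta_min=0.0):
    """Learning rate after `step` scheduler steps, cosine-annealed from base_lr to eta_min."""
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max)) / 2.0


for step in (0, 13, 14, 25, 37, 50):
    print(step, f"{cosine_annealing_lr(step): 4.4f}")
```

This appears consistent with the training log, where the printed learning rate drops from 0.0003 to 0.0002 at 15,000 frames, i.e. in the 15th batch, after 14 scheduler steps.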
Training loop#
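We now have all the pieces needed to code our training loop. The steps are:
Collect data
  Compute advantage
    Loop over the collected data to compute loss values
    Back propagate
    Optimize
    Repeat
  Repeat
Repeat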
logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""
# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )
            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()
    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))
    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.0884 (init= 9.0884), step count (max): 13, lr policy: 0.0003: 2%|▏ | 1000/50000 [00:03<02:33, 318.64it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.1140 (init= 9.0884), step count (max): 11, lr policy: 0.0003: 4%|▍ | 2000/50000 [00:06<02:30, 318.72it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.1447 (init= 9.0884), step count (max): 14, lr policy: 0.0003: 6%|▌ | 3000/50000 [00:09<02:23, 327.70it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.1873 (init= 9.0884), step count (max): 19, lr policy: 0.0003: 8%|▊ | 4000/50000 [00:12<02:18, 333.04it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.1913 (init= 9.0884), step count (max): 29, lr policy: 0.0003: 10%|█ | 5000/50000 [00:15<02:13, 337.07it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.2133 (init= 9.0884), step count (max): 25, lr policy: 0.0003: 12%|█▏ | 6000/50000 [00:17<02:09, 339.29it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.2289 (init= 9.0884), step count (max): 31, lr policy: 0.0003: 14%|█▍ | 7000/50000 [00:20<02:05, 341.83it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.2284 (init= 9.0884), step count (max): 25, lr policy: 0.0003: 16%|█▌ | 8000/50000 [00:23<02:02, 343.43it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.2448 (init= 9.0884), step count (max): 37, lr policy: 0.0003: 18%|█▊ | 9000/50000 [00:26<01:58, 345.21it/s]
eval cumulative reward: 82.2209 (init: 82.2209), eval step-count: 8, average reward= 9.2404 (init= 9.0884), step count (max): 35, lr policy: 0.0003: 20%|██ | 10000/50000 [00:29<01:57, 340.38it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2546 (init= 9.0884), step count (max): 44, lr policy: 0.0003: 22%|██▏ | 11000/50000 [00:32<01:53, 343.46it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2572 (init= 9.0884), step count (max): 46, lr policy: 0.0003: 24%|██▍ | 12000/50000 [00:35<01:50, 344.42it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2736 (init= 9.0884), step count (max): 69, lr policy: 0.0003: 26%|██▌ | 13000/50000 [00:38<01:46, 346.64it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2774 (init= 9.0884), step count (max): 76, lr policy: 0.0003: 28%|██▊ | 14000/50000 [00:41<01:43, 347.69it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2785 (init= 9.0884), step count (max): 58, lr policy: 0.0002: 30%|███ | 15000/50000 [00:43<01:40, 349.14it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2806 (init= 9.0884), step count (max): 47, lr policy: 0.0002: 32%|███▏ | 16000/50000 [00:46<01:37, 350.11it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.3005 (init= 9.0884), step count (max): 68, lr policy: 0.0002: 34%|███▍ | 17000/50000 [00:49<01:35, 345.33it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2830 (init= 9.0884), step count (max): 61, lr policy: 0.0002: 36%|███▌ | 18000/50000 [00:52<01:32, 347.67it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2819 (init= 9.0884), step count (max): 71, lr policy: 0.0002: 38%|███▊ | 19000/50000 [00:55<01:28, 349.15it/s]
eval cumulative reward: 203.5179 (init: 82.2209), eval step-count: 21, average reward= 9.2923 (init= 9.0884), step count (max): 86, lr policy: 0.0002: 40%|████ | 20000/50000 [00:58<01:25, 350.03it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2906 (init= 9.0884), step count (max): 84, lr policy: 0.0002: 42%|████▏ | 21000/50000 [01:01<01:22, 350.98it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2967 (init= 9.0884), step count (max): 81, lr policy: 0.0002: 44%|████▍ | 22000/50000 [01:03<01:20, 348.11it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.3004 (init= 9.0884), step count (max): 98, lr policy: 0.0002: 46%|████▌ | 23000/50000 [01:06<01:17, 349.52it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2927 (init= 9.0884), step count (max): 61, lr policy: 0.0002: 48%|████▊ | 24000/50000 [01:09<01:14, 349.57it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2918 (init= 9.0884), step count (max): 101, lr policy: 0.0002: 50%|█████ | 25000/50000 [01:12<01:12, 343.90it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2915 (init= 9.0884), step count (max): 79, lr policy: 0.0001: 52%|█████▏ | 26000/50000 [01:15<01:09, 346.59it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2822 (init= 9.0884), step count (max): 59, lr policy: 0.0001: 54%|█████▍ | 27000/50000 [01:18<01:06, 347.85it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2984 (init= 9.0884), step count (max): 84, lr policy: 0.0001: 56%|█████▌ | 28000/50000 [01:21<01:03, 348.77it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2883 (init= 9.0884), step count (max): 65, lr policy: 0.0001: 58%|█████▊ | 29000/50000 [01:24<01:00, 349.93it/s]
eval cumulative reward: 550.3049 (init: 82.2209), eval step-count: 58, average reward= 9.2851 (init= 9.0884), step count (max): 75, lr policy: 0.0001: 60%|██████ | 30000/50000 [01:26<00:57, 350.77it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.2847 (init= 9.0884), step count (max): 70, lr policy: 0.0001: 62%|██████▏ | 31000/50000 [01:29<00:54, 351.19it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.2906 (init= 9.0884), step count (max): 86, lr policy: 0.0001: 64%|██████▍ | 32000/50000 [01:32<00:52, 341.92it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.2892 (init= 9.0884), step count (max): 81, lr policy: 0.0001: 66%|██████▌ | 33000/50000 [01:35<00:49, 345.11it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.2997 (init= 9.0884), step count (max): 75, lr policy: 0.0001: 68%|██████▊ | 34000/50000 [01:38<00:46, 347.54it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.3058 (init= 9.0884), step count (max): 100, lr policy: 0.0001: 70%|███████ | 35000/50000 [01:41<00:43, 348.32it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.3044 (init= 9.0884), step count (max): 100, lr policy: 0.0001: 72%|███████▏ | 36000/50000 [01:44<00:40, 349.92it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.3062 (init= 9.0884), step count (max): 105, lr policy: 0.0001: 74%|███████▍ | 37000/50000 [01:47<00:36, 351.39it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.2877 (init= 9.0884), step count (max): 68, lr policy: 0.0000: 76%|███████▌ | 38000/50000 [01:49<00:34, 351.80it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.2877 (init= 9.0884), step count (max): 68, lr policy: 0.0000: 78%|███████▊ | 39000/50000 [01:52<00:31, 346.27it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.3085 (init= 9.0884), step count (max): 89, lr policy: 0.0000: 78%|███████▊ | 39000/50000 [01:52<00:31, 346.27it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.3085 (init= 9.0884), step count (max): 89, lr policy: 0.0000: 80%|████████ | 40000/50000 [01:55<00:28, 348.63it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.2986 (init= 9.0884), step count (max): 76, lr policy: 0.0000: 80%|████████ | 40000/50000 [01:55<00:28, 348.63it/s]
eval cumulative reward: 512.7717 (init: 82.2209), eval step-count: 54, average reward= 9.2986 (init= 9.0884), step count (max): 76, lr policy: 0.0000: 82%|████████▏ | 41000/50000 [01:58<00:25, 350.20it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3021 (init= 9.0884), step count (max): 101, lr policy: 0.0000: 82%|████████▏ | 41000/50000 [01:58<00:25, 350.20it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3021 (init= 9.0884), step count (max): 101, lr policy: 0.0000: 84%|████████▍ | 42000/50000 [02:01<00:22, 348.22it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3101 (init= 9.0884), step count (max): 102, lr policy: 0.0000: 84%|████████▍ | 42000/50000 [02:01<00:22, 348.22it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3101 (init= 9.0884), step count (max): 102, lr policy: 0.0000: 86%|████████▌ | 43000/50000 [02:04<00:20, 349.91it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3019 (init= 9.0884), step count (max): 76, lr policy: 0.0000: 86%|████████▌ | 43000/50000 [02:04<00:20, 349.91it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3019 (init= 9.0884), step count (max): 76, lr policy: 0.0000: 88%|████████▊ | 44000/50000 [02:07<00:17, 351.24it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3056 (init= 9.0884), step count (max): 103, lr policy: 0.0000: 88%|████████▊ | 44000/50000 [02:07<00:17, 351.24it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3056 (init= 9.0884), step count (max): 103, lr policy: 0.0000: 90%|█████████ | 45000/50000 [02:09<00:14, 351.09it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.2996 (init= 9.0884), step count (max): 83, lr policy: 0.0000: 90%|█████████ | 45000/50000 [02:09<00:14, 351.09it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.2996 (init= 9.0884), step count (max): 83, lr policy: 0.0000: 92%|█████████▏| 46000/50000 [02:12<00:11, 351.84it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3039 (init= 9.0884), step count (max): 78, lr policy: 0.0000: 92%|█████████▏| 46000/50000 [02:12<00:11, 351.84it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3039 (init= 9.0884), step count (max): 78, lr policy: 0.0000: 94%|█████████▍| 47000/50000 [02:15<00:08, 346.69it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3125 (init= 9.0884), step count (max): 104, lr policy: 0.0000: 94%|█████████▍| 47000/50000 [02:15<00:08, 346.69it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3125 (init= 9.0884), step count (max): 104, lr policy: 0.0000: 96%|█████████▌| 48000/50000 [02:18<00:05, 349.01it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3079 (init= 9.0884), step count (max): 122, lr policy: 0.0000: 96%|█████████▌| 48000/50000 [02:18<00:05, 349.01it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3079 (init= 9.0884), step count (max): 122, lr policy: 0.0000: 98%|█████████▊| 49000/50000 [02:21<00:02, 350.66it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3063 (init= 9.0884), step count (max): 82, lr policy: 0.0000: 98%|█████████▊| 49000/50000 [02:21<00:02, 350.66it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3063 (init= 9.0884), step count (max): 82, lr policy: 0.0000: 100%|██████████| 50000/50000 [02:24<00:00, 351.37it/s]
eval cumulative reward: 465.9574 (init: 82.2209), eval step-count: 49, average reward= 9.3010 (init= 9.0884), step count (max): 72, lr policy: 0.0000: 100%|██████████| 50000/50000 [02:24<00:00, 351.37it/s]
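Results
Given the full 1M-step budget, the algorithm should reach the maximum step count of 1000 steps, which is the number of steps after which a trajectory is truncated, before that cap is hit. The run shown above stops after 50,000 frames, so its step counts stay well below that limit.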
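One way to check this claim against your own run is to scan the logged maximum step counts for the first iteration that reaches the truncation horizon. The sketch below assumes that the training loop stores one maximum step count per iteration under the "step_count" key, which is the key the plots read. It also hard-codes the 1000-step horizon mentioned above. `first_truncated_iteration` is a hypothetical helper, not part of TorchRL.

```python
# Truncation horizon stated in the text (assumed fixed at 1000 steps)
MAX_EPISODE_STEPS = 1000


def first_truncated_iteration(step_counts, horizon=MAX_EPISODE_STEPS):
    """Return the index of the first iteration whose max step count reached
    the truncation horizon, or None if no logged episode got that far."""
    for i, count in enumerate(step_counts):
        if count >= horizon:
            return i
    return None


# Max step counts copied from some of the progress-bar lines above (50k-frame run)
short_run = [84, 65, 75, 70, 86, 81, 100, 105, 122, 82, 72]
print(first_truncated_iteration(short_run))  # None: the short run never hits 1000

# Hypothetical values for a longer run whose episodes eventually last the full horizon
long_run = [84, 350, 720, 1000, 1000]
print(first_truncated_iteration(long_run))  # 3
```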
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()
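The plotting code assumes that logs maps each metric name to a list with one entry per logged round. In the progress bar above, the evaluation metrics only change about every 10 iterations. If that is the case, the two test panels contain roughly a tenth as many points as the training panels, and their x-axis counts evaluation rounds rather than training iterations. The following minimal sketch shows that bookkeeping. It assumes a defaultdict(list), which is imported at the top of the tutorial, and an evaluation every 10 iterations. `record_iteration` and `EVAL_EVERY` are hypothetical names, and the metric values are placeholders.

```python
from collections import defaultdict

EVAL_EVERY = 10  # assumed evaluation interval, in training iterations


def record_iteration(logs, i, avg_reward, max_step_count, evaluate_fn):
    """Append one training iteration's metrics under the keys the plots read.

    `evaluate_fn` is a hypothetical callable returning (eval_return, eval_steps);
    it is only called every EVAL_EVERY iterations.
    """
    logs["reward"].append(avg_reward)
    logs["step_count"].append(max_step_count)
    if i % EVAL_EVERY == 0:
        eval_return, eval_steps = evaluate_fn()
        logs["eval reward (sum)"].append(eval_return)
        logs["eval step_count"].append(eval_steps)


logs = defaultdict(list)
# 50 iterations, e.g. 50,000 frames at 1,000 frames per batch as in the progress bar
for i in range(50):
    record_iteration(logs, i, avg_reward=9.3, max_step_count=80,
                     evaluate_fn=lambda: (500.0, 50))

print(len(logs["reward"]), len(logs["eval reward (sum)"]))  # 50 5
```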

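Conclusion and next steps
In this tutorial, we have learned:
How to create and customize an environment with torchrl;
How to write a model and a loss function;
How to set up a typical training loop.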
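If you want to experiment further with this tutorial, you can apply the following modifications:
From an efficiency perspective, we could run several simulations in parallel to speed up data collection. Check ParallelEnv for further information.
From a logging perspective, one could add a torchrl.record.VideoRecorder transform to the environment after requesting rendering, to get a visual rendering of the inverted pendulum in action. Check torchrl.record to learn more.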
Total running time of the script: (2 minutes 26.008 seconds)