(Beta) Running the compiled optimizer with an LR Scheduler#

Created On: May 21, 2024 | Last Updated: May 21, 2024 | Last Verified: Nov 5, 2024

Author: Michael Lazos

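The optimizer is a key algorithm for training any deep learning model. In this example, we will show how to pair an optimizer compiled with torch.compile with a learning rate (LR) scheduler to accelerate training convergence.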

Note

This tutorial requires PyTorch 2.3.0 or later.

Model Setup#

For this example, we will use a simple sequence of linear layers.

import torch

# Create simple model
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, bias=False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()

Setting up and running the compiled optimizer with an LR Scheduler#

In this section, we will use the Adam optimizer with the LinearLR scheduler and create a helper function that wraps the step() call of each in torch.compile().

Note

torch.compile is only supported on CUDA devices with a compute capability of 7.0 or higher.

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

# !!! IMPORTANT !!! Wrap the lr in a Tensor if we are pairing
# the optimizer with an LR Scheduler.
# Without this, torch.compile will recompile as the value of the LR
# changes.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()


# Warmup runs to compile the function
for _ in range(5):
    fn()
    print(opt.param_groups[0]["lr"])
tensor(0.0047)
tensor(0.0060)
tensor(0.0073)
tensor(0.0087)
tensor(0.0100)
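
The printed values follow the linear ramp that LinearLR applies to the base learning rate. The sketch below reproduces them in plain Python, assuming the scheduler's documented defaults (start_factor=1/3, end_factor=1.0) together with the total_iters=5 and base LR of 0.01 used above. The helper name linear_lr is hypothetical and is not part of PyTorch; it only illustrates the closed-form schedule.

```python
def linear_lr(base_lr, step, total_iters=5, start_factor=1.0 / 3, end_factor=1.0):
    """Closed-form LinearLR value after `step` scheduler steps (illustrative helper)."""
    # The multiplicative factor moves linearly from start_factor to end_factor
    # over total_iters steps and stays at end_factor afterwards.
    progress = min(step, total_iters) / total_iters
    factor = start_factor + (end_factor - start_factor) * progress
    return base_lr * factor


# Step 0 is the LR right after the scheduler is constructed;
# steps 1..5 correspond to the five values printed by the loop above.
schedule = [linear_lr(0.01, step) for step in range(6)]
for step, lr in enumerate(schedule):
    print(f"step {step}: lr = {lr:.4f}")
```

Steps 1 through 5 round to 0.0047, 0.0060, 0.0073, 0.0087, and 0.0100, which matches the tensors printed by the warmup loop.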

Extension: What happens with a non-tensor LR?#

For the curious, we will show what happens with torch.compile when we do not wrap the LR in a tensor.

# No longer wrap the LR in a tensor here
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()

# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)

# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(5):
    fn()
V0603 01:03:54.280000 33208 torch/_dynamo/guards.py:5188] [1/1] [__recompiles] Recompiling function wrapper in /usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py:514
V0603 01:03:54.280000 33208 torch/_dynamo/guards.py:5188] [1/1] [__recompiles]     triggered by the following guard failure(s):
V0603 01:03:54.280000 33208 torch/_dynamo/guards.py:5188] [1/1] [__recompiles]     - 1/0: Cache line invalidated because L['args'][0] got deallocated
V0603 01:03:54.324000 33208 torch/_dynamo/guards.py:5188] [2/1] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:214
V0603 01:03:54.324000 33208 torch/_dynamo/guards.py:5188] [2/1] [__recompiles]     triggered by the following guard failure(s):
V0603 01:03:54.324000 33208 torch/_dynamo/guards.py:5188] [2/1] [__recompiles]     - 2/0: Cache line invalidated because L['self'] got deallocated
V0603 01:03:57.437000 33208 torch/_dynamo/guards.py:5188] [2/2] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:214
V0603 01:03:57.437000 33208 torch/_dynamo/guards.py:5188] [2/2] [__recompiles]     triggered by the following guard failure(s):
V0603 01:03:57.437000 33208 torch/_dynamo/guards.py:5188] [2/2] [__recompiles]     - 2/1: self.param_groups[0]['lr'] == 0.003333333333333333  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:03:57.437000 33208 torch/_dynamo/guards.py:5188] [2/2] [__recompiles]     - 2/0: Cache line invalidated because L['self'] got deallocated
V0603 01:03:59.862000 33208 torch/_dynamo/guards.py:5188] [2/3] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:214
V0603 01:03:59.862000 33208 torch/_dynamo/guards.py:5188] [2/3] [__recompiles]     triggered by the following guard failure(s):
V0603 01:03:59.862000 33208 torch/_dynamo/guards.py:5188] [2/3] [__recompiles]     - 2/2: self.param_groups[0]['lr'] == 0.004666666666666667  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:03:59.862000 33208 torch/_dynamo/guards.py:5188] [2/3] [__recompiles]     - 2/1: self.param_groups[0]['lr'] == 0.003333333333333333  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:03:59.862000 33208 torch/_dynamo/guards.py:5188] [2/3] [__recompiles]     - 2/0: Cache line invalidated because L['self'] got deallocated
V0603 01:04:02.301000 33208 torch/_dynamo/guards.py:5188] [2/4] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:214
V0603 01:04:02.301000 33208 torch/_dynamo/guards.py:5188] [2/4] [__recompiles]     triggered by the following guard failure(s):
V0603 01:04:02.301000 33208 torch/_dynamo/guards.py:5188] [2/4] [__recompiles]     - 2/3: self.param_groups[0]['lr'] == 0.006000000000000001  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:04:02.301000 33208 torch/_dynamo/guards.py:5188] [2/4] [__recompiles]     - 2/2: self.param_groups[0]['lr'] == 0.004666666666666667  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:04:02.301000 33208 torch/_dynamo/guards.py:5188] [2/4] [__recompiles]     - 2/1: self.param_groups[0]['lr'] == 0.003333333333333333  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:04:02.301000 33208 torch/_dynamo/guards.py:5188] [2/4] [__recompiles]     - 2/0: Cache line invalidated because L['self'] got deallocated
V0603 01:04:04.965000 33208 torch/_dynamo/guards.py:5188] [2/5] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:214
V0603 01:04:04.965000 33208 torch/_dynamo/guards.py:5188] [2/5] [__recompiles]     triggered by the following guard failure(s):
V0603 01:04:04.965000 33208 torch/_dynamo/guards.py:5188] [2/5] [__recompiles]     - 2/4: self.param_groups[0]['lr'] == 0.007333333333333335  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:04:04.965000 33208 torch/_dynamo/guards.py:5188] [2/5] [__recompiles]     - 2/3: self.param_groups[0]['lr'] == 0.006000000000000001  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:04:04.965000 33208 torch/_dynamo/guards.py:5188] [2/5] [__recompiles]     - 2/2: self.param_groups[0]['lr'] == 0.004666666666666667  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:04:04.965000 33208 torch/_dynamo/guards.py:5188] [2/5] [__recompiles]     - 2/1: self.param_groups[0]['lr'] == 0.003333333333333333  # (_dynamo/output_graph.py:3205 in remove_tensorify_specialized_graphargs)
V0603 01:04:04.965000 33208 torch/_dynamo/guards.py:5188] [2/5] [__recompiles]     - 2/0: Cache line invalidated because L['self'] got deallocated

With this example, we can see that the optimizer is recompiled repeatedly, once for each new LR value, because the guard on the lr in param_groups[0] fails.
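
To build intuition for why a Python float triggers recompiles while a tensor does not, here is a toy model in plain Python. It is only an analogy and not how torch.compile is implemented: ToyCompiler and BoxedLR are hypothetical names, and the assumption (suggested by the `self.param_groups[0]['lr'] == ...` guards in the log) is that a float is specialized on its exact value, whereas a mutable container is checked only by its kind.

```python
class ToyCompiler:
    """Toy stand-in for a guard-checked compile cache (not Dynamo's real design)."""

    def __init__(self):
        self.cache = {}
        self.compile_count = 0

    def run(self, lr):
        # Assumption: a Python float is guarded on its exact value, while a
        # boxed value (standing in for a tensor) is guarded only on its type.
        if isinstance(lr, float):
            key = ("float", lr)
        else:
            key = ("boxed", type(lr).__name__)
        if key not in self.cache:
            # Cache miss: the guard failed, so we "recompile".
            self.compile_count += 1
            self.cache[key] = f"compiled#{self.compile_count}"
        return self.cache[key]


class BoxedLR:
    """Hypothetical mutable holder standing in for a 0-dim tensor."""

    def __init__(self, value):
        self.value = value


# Five distinct LR values, as a linear warmup schedule would produce.
lrs = [0.01 * (1 / 3 + (2 / 3) * step / 5) for step in range(5)]

float_compiler = ToyCompiler()
for lr in lrs:
    float_compiler.run(lr)  # a new float value each time, so a new cache entry

boxed_compiler = ToyCompiler()
boxed = BoxedLR(lrs[0])
for lr in lrs:
    boxed.value = lr  # mutated in place, as a scheduler would update a tensor
    boxed_compiler.run(boxed)

print(float_compiler.compile_count)  # 5
print(boxed_compiler.compile_count)  # 1
```

In this toy model the float path compiles once per distinct value, while the boxed path compiles once and is then reused, which mirrors the behavior seen in the two runs above.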

Conclusion#

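In this tutorial, we showed how to pair an optimizer compiled with torch.compile with an LR scheduler to accelerate training convergence. We used a model consisting of a simple sequence of linear layers, together with the Adam optimizer and the LinearLR scheduler, to demonstrate how the LR changes across iterations.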

Total running time of the script: (0 minutes 18.056 seconds)