注意
转到页面末尾 下载完整示例代码。
从检查点加载 nn.Module 的技巧#
创建日期:2023 年 10 月 3 日 | 最后更新:2024 年 8 月 27 日 | 最后验证:2024 年 11 月 5 日
如果您正在加载检查点并希望尽可能减少计算和内存占用,本教程将分享一些推荐的实践方法。我们将重点讨论:
torch.load中的mmap关键字参数torch.device()上下文管理器nn.Module.load_state_dict()中的assign关键字参数
注意
本教程需要 PyTorch 2.1.0 或更高版本。
让我们考虑一个包含一系列线性层的简单 nn.Module
import torch
from torch import nn
import time
class SomeModule(torch.nn.Module):
def __init__(self, size):
super().__init__()
self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])
def forward(self, x):
return self.linears(x)
m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')
以下代码片段演示了如何使用 torch.load 的 mmap 关键字参数、torch.device() 上下文管理器以及 nn.Module.load_state_dict() 的 assign 关键字参数。
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
with torch.device('meta'):
meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)
<All keys matched successfully>
将下方的代码片段与上面的进行比较
state_dict = torch.load('checkpoint.pth', weights_only=True)
m = SomeModule(1000)
m.load_state_dict(state_dict)
<All keys matched successfully>
第二个示例没有使用上述任何功能,因此在加载检查点时的计算和内存效率会较低。在接下来的章节中,我们将更详细地讨论每个功能。
使用 torch.load(mmap=True)#
首先,让我们考虑使用 torch.load 加载检查点时会发生什么。当我们使用 torch.save 保存检查点时,张量存储会被标记为它们保存时所在的设备。使用 torch.load 时,张量存储会被加载到它们被标记的设备上(除非使用 map_location 标志覆盖此行为)。为了便于解释,假设这些张量是保存在 CPU 上的。这意味着在执行第一行代码时,所有张量存储都会被加载到 CPU 内存中,当出现以下情况时,这可能是不理想的:
CPU 内存小于检查点文件的大小。
在执行例如逐张量处理之前,需要等待整个检查点加载到内存中。
start_time = time.time()
state_dict = torch.load('checkpoint.pth', weights_only=True)
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
loading time without mmap=0.034053802490234375
torch.load 的 mmap 关键字参数试图解决上述两个问题。顾名思义,torch.load 的 mmap 参数利用了 mmap 调用,该调用将磁盘上的文件映射到虚拟内存中,并让操作系统自动处理加载和卸载物理内存的过程。当传入此标志时,张量存储将被内存映射。
start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")
loading time with mmap=0.0027823448181152344
如上所述,可以使用此参数对检查点进行逐张量处理,而无需预先将所有张量存储加载到 CPU 内存中。例如:
def my_special_routine(t, device):
# this could be a much fancier operation
return t.to(dtype=torch.bfloat16, device=device)
def my_processing_function(key, device):
t = state_dict[key]
processed_t = my_special_routine(t, device)
del t
state_dict[key] = processed_t
for key in state_dict.keys():
device = torch.device('cuda')
my_processing_function(key, device)
使用 torch.device('meta')#
接下来,让我们考虑模块的创建。
m = SomeModule(1000)
这会为所有参数/缓冲区分配内存,并根据 SomeModule.__init__() 中定义的默认初始化方案对它们进行初始化。当我们想从检查点加载时,这样做是浪费的,原因如下:
初始化内核的结果会被
load_state_dict()覆盖而从未使用,因此初始化是浪费的。我们在内存中为这些参数/缓冲区分配了内存,而保存的 state dictionary 的
torch.load也会在内存中为检查点中的参数/缓冲区分配内存。
为了解决这两个问题,我们可以在实例化 nn.Module() 时使用带有 device='meta' 的 torch.device() 上下文管理器。
torch.device() 上下文管理器确保工厂调用执行时,就像它们被传入了指定的 device 参数一样。torch.device('meta') 上的张量不携带数据。但是,它们拥有张量所携带的所有其他元数据,例如 .size()、.stride()、.requires_grad 等。
with torch.device('meta'):
new_m = SomeModule(1000)
使用 load_state_dict(assign=True)#
接下来,我们考虑 state dictionary 的加载。
m.load_state_dict(state_dict)
<All keys matched successfully>
nn.Module.load_state_dict() 通常通过原地操作 param_in_model.copy_(param_in_state_dict) 来实现。这意味着 state dictionary 中具有相应键的参数/缓冲区会被复制到 nn.Module 的参数/缓冲区中。
然而,向 meta 设备上的张量进行原地复制是一个空操作(no-op)。为了避免这种情况,我们可以向 load_state_dict() 传递 assign=True 关键字参数。
这里需要注意的一个问题是,由于优化器持有对 nn.Module.parameters() 的引用,如果传入了 assign=True,则优化器必须在模型从 state dict 加载后进行初始化。
# As of PyTorch 2.3.0, one can use ``torch.__future__.set_swap_module_params_on_conversion`` to
# avoid this caveat. This `recipe <https://pytorch.ac.cn/tutorials/recipes/recipes/swap_tensors.html>`_
# provides more details.
new_m.load_state_dict(state_dict, assign=True)
# Before 2.3.0, this MUST be done AFTER the load_state_dict with assign.
# In versions >= 2.3.0, one can consider setting ``torch.__future__.set_swap_module_params_on_conversion``
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)
结论#
总结一下,在本教程中,我们学习了 torch.load(mmap=True)、带有 device=meta 的 torch.device() 上下文管理器,以及 nn.Module.load_state_dict(assign=True),以及当从检查点加载模型时如何利用这些工具。
脚本总运行时间: (0 分钟 0.541 秒)