注意
跳转至页面底部下载完整示例代码。
nn.Module 中用于 load_state_dict 和张量子类的扩展点#
创建日期:2024年4月19日 | 最后更新:2024年4月19日 | 最后验证:未验证
本教程介绍了一个新的实用函数 torch.utils.swap_tensors,以及它在 nn.Module 中集成的两个新扩展点:
nn.Module.to()及其相关方法nn.Module.load_state_dict()
注意
本教程要求 PyTorch 2.3.0 或更高版本。
torch.utils.swap_tensors#
torch.utils.swap_tensors(以下简称 swap_tensors)是一个接收两个 Python 张量并交换它们的实用函数。
import torch
import torch.nn as nn
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")
Before swapping, t1: tensor([0, 1]), t2: tensor([0, 1, 2])
After swapping, t1: tensor([0, 1, 2]), t2: tensor([0, 1])
更具体地说,swap_tensors 会交换两个张量的 Python __class__、__dict__ 和 __slots__,以及它们关联的 at::Tensor。
在 nn.Module 中的应用#
当模块之外的 Python 对象持有模块参数的引用时,此工具对 nn.Module 非常有用。如果 nn.Module 以非原地(out-of-place)方式修改其任何参数,持有该参数引用的对象将无法感知到此变化。一个典型的例子是优化器(optimizer),它持有 nn.Module 参数的引用。这会导致一个隐蔽的正确性问题:optimizer.step() 将正常运行且不报错,但 nn.Module 的权重却未得到更新。
mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
mod.weight = torch.nn.Parameter(2 * mod.weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
weight in mod: Parameter containing:
tensor([[ 0.1708],
[-0.2648]], requires_grad=True)
weight in optimizer: [Parameter containing:
tensor([[ 0.1708],
[-0.2648]], requires_grad=True)]
weight in mod: Parameter containing:
tensor([[ 0.3416],
[-0.5296]], requires_grad=True)
weight in optimizer: [Parameter containing:
tensor([[ 0.1708],
[-0.2648]], requires_grad=True)]
nn.Module.load_state_dict()#
根据传递给 load_state_dict() 的 assign 关键字参数的值,有两种加载 state_dict 的方式:
assign=False:保留module.param的属性,仅获取state_dict['param_name']的值。assign=True:保留state_dict['param_name']的属性和值。
以前,这些分别通过原地 copy_ 和 __setattr__ 实现。在现有实现下,每种方法都有其自身的局限性——assign=False 要求 state_dict 中的参数类型必须与模块中的参数类型相同,而 assign=True 则要求任何持有该模块参数引用的对象必须在 nn.Module.load_state_dict() 之后进行初始化。
现在,我们通过向 load_state_dict() 添加 swap_tensors 路径并引入新的扩展点 torch.Tensor.module_load(self, other, assign=False) 来解决这两个限制。当通过上述 __future__ 启用 swap_tensors 路径时,我们可以使用 module_load 的 __torch_function__ 处理程序来对 state_dict 中的值应用自定义转换。该转换的结果将与模块中的参数进行交换。
在下面的示例中,我们将使用上面定义的 MyQuantizedLinearWeight 子类来演示如何在加载 state_dict 时利用这些特性将自定义量化方案应用于线性层的权重。
回想一下,如果 self 或 other(在此情况下为 param 或 state_dict[param_key])是 MyQuantizedLinearWeight 子类,则会调用 module_load 的 __torch_function__ 处理程序。
假设我们期望 state_dict 包含普通张量,而模块包含 MyQuantizedLinearWeight 参数,且我们希望将 state_dict 中的张量转换为该子类。那么我们可以按如下方式为 torch.Tensor.module_load 定义 __torch_function__ 处理程序。
@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if func is torch.Tensor.module_load:
dest, src = args[0], args[1]
assert type(dest) == cls and type(src) == torch.Tensor
return MyQuantizedLinearWeight(src, dest.scale)
else:
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
MyQuantizedLinearWeight.__torch_function__ = custom_torch_function
首先,让我们在 meta 设备上创建一个模型框架,以避免物化存储。我们将模块中的所有权重转换为 MyQuantizedLinearWeight 子类,同时保持偏置(bias)不变。
def fn(m):
if isinstance(m, nn.Linear):
requires_grad = m.weight.requires_grad
m.weight = torch.nn.Parameter(
MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad
)
with torch.device("meta"):
m = nn.Linear(3, 5)
m.apply(fn)
然后我们可以加载 state_dict。注意我们使用了 assign=True,因为对于偏置,我们希望保留 state_dict 中张量的属性(例如,我们不希望偏置在加载后仍处于 meta 设备上)。
torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")
state_dict = nn.Linear(3, 5).state_dict()
print(f"state_dict:\n {state_dict}")
m.load_state_dict(state_dict, assign=True)
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")
Before: id(weight)=140458790734320, id(bias)=140457731527024
m.state_dict() before load_state_dict():
OrderedDict([('weight', MyQuantizedLinearWeight(tensor(..., device='meta', size=(5, 3)), scale=0.5)), ('bias', tensor(..., device='meta', size=(5,)))])
state_dict:
OrderedDict([('weight', tensor([[-0.2628, -0.0056, -0.5122],
[ 0.4156, 0.3666, 0.4726],
[ 0.0364, 0.0934, -0.2108],
[ 0.1672, 0.4529, -0.0464],
[ 0.4142, -0.4861, 0.3882]])), ('bias', tensor([ 0.3408, -0.5007, -0.0986, 0.5373, -0.5751]))])
After: id(weight)=140458790734320, id(bias)=140457731527024
m.state_dict() after load_state_dict():
OrderedDict([('weight', MyQuantizedLinearWeight(tensor([[-0.2628, -0.0056, -0.5122],
[ 0.4156, 0.3666, 0.4726],
[ 0.0364, 0.0934, -0.2108],
[ 0.1672, 0.4529, -0.0464],
[ 0.4142, -0.4861, 0.3882]]), scale=0.5)), ('bias', tensor([ 0.3408, -0.5007, -0.0986, 0.5373, -0.5751]))])
以上是一个关于如何使用 nn.Module.load_state_dict() 中新扩展点的玩具示例。人们还可以设想其他场景,例如当 state_dict 中存在张量子类而模块中是普通 nn.Parameters/张量时,或者两者都是张量子类时。根据具体用例,我们可以为 module_load 定义 __torch_function__ 处理程序以按需应用转换。
结论#
在本教程中,我们了解了 swap_tensors、在 nn.Module 中保留参数引用的重要性,以及如何使用受 torch.__future__.set_swap_module_params_on_conversion 门控的两个新扩展点。
脚本总运行时间:(0 分 0.016 秒)