评价此页
nn.Module 用于 load_state_dict 和张量子类"

nn.Module 中用于 load_state_dict 和张量子类的扩展点#

创建日期:2024年4月19日 | 最后更新:2024年4月19日 | 最后验证:未验证

作者: Mikayla Gawarecki

本教程介绍了一个新的实用函数 torch.utils.swap_tensors,以及它在 nn.Module 中集成的两个新扩展点:

  • nn.Module.to() 及其相关方法

  • nn.Module.load_state_dict()

注意

本教程要求 PyTorch 2.3.0 或更高版本。

torch.utils.swap_tensors#

torch.utils.swap_tensors(以下简称 swap_tensors)是一个接收两个 Python 张量并交换它们的实用函数。

import torch
import torch.nn as nn
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")
Before swapping, t1: tensor([0, 1]), t2: tensor([0, 1, 2])
After swapping, t1: tensor([0, 1, 2]), t2: tensor([0, 1])

更具体地说,swap_tensors 会交换两个张量的 Python __class____dict____slots__,以及它们关联的 at::Tensor

nn.Module 中的应用#

当模块之外的 Python 对象持有模块参数的引用时,此工具对 nn.Module 非常有用。如果 nn.Module 以非原地(out-of-place)方式修改其任何参数,持有该参数引用的对象将无法感知到此变化。一个典型的例子是优化器(optimizer),它持有 nn.Module 参数的引用。这会导致一个隐蔽的正确性问题:optimizer.step() 将正常运行且不报错,但 nn.Module 的权重却未得到更新。

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
mod.weight = torch.nn.Parameter(2 * mod.weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
weight in mod: Parameter containing:
tensor([[ 0.1708],
        [-0.2648]], requires_grad=True)
weight in optimizer: [Parameter containing:
tensor([[ 0.1708],
        [-0.2648]], requires_grad=True)]
weight in mod: Parameter containing:
tensor([[ 0.3416],
        [-0.5296]], requires_grad=True)
weight in optimizer: [Parameter containing:
tensor([[ 0.1708],
        [-0.2648]], requires_grad=True)]

nn.Module.load_state_dict()#

根据传递给 load_state_dict()assign 关键字参数的值,有两种加载 state_dict 的方式:

  • assign=False:保留 module.param 的属性,仅获取 state_dict['param_name'] 的值。

  • assign=True:保留 state_dict['param_name'] 的属性和值。

以前,这些分别通过原地 copy___setattr__ 实现。在现有实现下,每种方法都有其自身的局限性——assign=False 要求 state_dict 中的参数类型必须与模块中的参数类型相同,而 assign=True 则要求任何持有该模块参数引用的对象必须在 nn.Module.load_state_dict() 之后进行初始化。

现在,我们通过向 load_state_dict() 添加 swap_tensors 路径并引入新的扩展点 torch.Tensor.module_load(self, other, assign=False) 来解决这两个限制。当通过上述 __future__ 启用 swap_tensors 路径时,我们可以使用 module_load__torch_function__ 处理程序来对 state_dict 中的值应用自定义转换。该转换的结果将与模块中的参数进行交换。

在下面的示例中,我们将使用上面定义的 MyQuantizedLinearWeight 子类来演示如何在加载 state_dict 时利用这些特性将自定义量化方案应用于线性层的权重。

回想一下,如果 selfother(在此情况下为 paramstate_dict[param_key])是 MyQuantizedLinearWeight 子类,则会调用 module_load__torch_function__ 处理程序。

假设我们期望 state_dict 包含普通张量,而模块包含 MyQuantizedLinearWeight 参数,且我们希望将 state_dict 中的张量转换为该子类。那么我们可以按如下方式为 torch.Tensor.module_load 定义 __torch_function__ 处理程序。

@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    if func is torch.Tensor.module_load:
        dest, src = args[0], args[1]
        assert type(dest) == cls and type(src) == torch.Tensor
        return MyQuantizedLinearWeight(src, dest.scale)
    else:
        with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)

MyQuantizedLinearWeight.__torch_function__ = custom_torch_function

首先,让我们在 meta 设备上创建一个模型框架,以避免物化存储。我们将模块中的所有权重转换为 MyQuantizedLinearWeight 子类,同时保持偏置(bias)不变。

def fn(m):
    if isinstance(m, nn.Linear):
        requires_grad = m.weight.requires_grad
        m.weight = torch.nn.Parameter(
                    MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad
                   )

with torch.device("meta"):
    m = nn.Linear(3, 5)
    m.apply(fn)

然后我们可以加载 state_dict。注意我们使用了 assign=True,因为对于偏置,我们希望保留 state_dict 中张量的属性(例如,我们不希望偏置在加载后仍处于 meta 设备上)。

torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")
state_dict = nn.Linear(3, 5).state_dict()
print(f"state_dict:\n {state_dict}")
m.load_state_dict(state_dict, assign=True)
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")
Before: id(weight)=140458790734320, id(bias)=140457731527024
m.state_dict() before load_state_dict():
 OrderedDict([('weight', MyQuantizedLinearWeight(tensor(..., device='meta', size=(5, 3)), scale=0.5)), ('bias', tensor(..., device='meta', size=(5,)))])
state_dict:
 OrderedDict([('weight', tensor([[-0.2628, -0.0056, -0.5122],
        [ 0.4156,  0.3666,  0.4726],
        [ 0.0364,  0.0934, -0.2108],
        [ 0.1672,  0.4529, -0.0464],
        [ 0.4142, -0.4861,  0.3882]])), ('bias', tensor([ 0.3408, -0.5007, -0.0986,  0.5373, -0.5751]))])
After: id(weight)=140458790734320, id(bias)=140457731527024
m.state_dict() after load_state_dict():
 OrderedDict([('weight', MyQuantizedLinearWeight(tensor([[-0.2628, -0.0056, -0.5122],
        [ 0.4156,  0.3666,  0.4726],
        [ 0.0364,  0.0934, -0.2108],
        [ 0.1672,  0.4529, -0.0464],
        [ 0.4142, -0.4861,  0.3882]]), scale=0.5)), ('bias', tensor([ 0.3408, -0.5007, -0.0986,  0.5373, -0.5751]))])

以上是一个关于如何使用 nn.Module.load_state_dict() 中新扩展点的玩具示例。人们还可以设想其他场景,例如当 state_dict 中存在张量子类而模块中是普通 nn.Parameters/张量时,或者两者都是张量子类时。根据具体用例,我们可以为 module_load 定义 __torch_function__ 处理程序以按需应用转换。

结论#

在本教程中,我们了解了 swap_tensors、在 nn.Module 中保留参数引用的重要性,以及如何使用受 torch.__future__.set_swap_module_params_on_conversion 门控的两个新扩展点。

脚本总运行时间:(0 分 0.016 秒)