跳过模块参数初始化#
简介#
当一个模块被创建时,其可学习参数会根据与模块类型关联的默认初始化方案进行初始化。例如,torch.nn.Linear
模块的 weight 参数是从 uniform(-1/sqrt(in_features), 1/sqrt(in_features)) 分布中初始化的。如果需要其他初始化方案,传统上需要先实例化模块,然后再重新初始化参数。
from torch import nn
# Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
m = nn.Linear(10, 5)
# Re-initialize weight from a different distribution.
nn.init.orthogonal_(m.weight)
在这种情况下,构造期间进行的初始化是浪费的计算,而且如果 weight 参数很大,可能会产生不小的开销。
跳过初始化#
现在可以在模块构造期间跳过参数初始化,避免浪费计算。这可以通过使用 torch.nn.utils.skip_init()
函数轻松完成。
from torch import nn
from torch.nn.utils import skip_init
m = skip_init(nn.Linear, 10, 5)
# Example: Do custom, non-default parameter initialization.
nn.init.orthogonal_(m.weight)
这可以应用于满足下面“更新模块以支持跳过初始化”部分中描述的条件的任何模块。请注意,torch.nn 提供的所有模块都满足这些条件,因此支持跳过初始化。
更新模块以支持跳过初始化#
由于 torch.nn.utils.skip_init()
的实现方式(请参阅“实现细节”),模块必须满足两个要求才能与该函数兼容。通过遵循这些要求,您可以为自定义模块选择启用参数初始化跳过功能。
1. 模块的构造函数必须接受一个 device 关键字参数,该参数会传递给构造期间创建的任何参数或缓冲区。
2. 模块在构造函数中不得对参数或缓冲区执行任何计算,除了初始化(即 torch.nn.init 中的函数)。
以下示例演示了一个已更新以支持 device 关键字参数的模块,方法是将其传递给任何创建的参数、缓冲区或子模块。
import torch
from torch import nn
class MyModule(torch.nn.Module):
def __init__(self, foo, bar, device=None):
super().__init__()
# ==== Case 1: Module creates parameters directly. ====
# Pass device along to any created parameters.
self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))
# To ensure support for the meta device, avoid using ops except those in
# torch.nn.init on parameters in your module's constructor.
with torch.no_grad():
nn.init.kaiming_uniform_(self.param1)
nn.init.uniform_(self.param2)
# ==== Case 2: Module creates submodules. ====
# Pass device along recursively. All submodules will need to support
# them as well; this is the case for all torch.nn provided modules.
self.fc = nn.Linear(bar, 5, device=device)
# This also works with containers.
self.linears = nn.Sequential(
nn.Linear(5, 5, device=device),
nn.Linear(5, 1, device=device)
)
# ==== Case 3: Module creates buffers. ====
# Pass device along during buffer tensor creation.
self.register_buffer('some_buffer', torch.ones(7, device=device))
...
实现细节#
在后台,torch.nn.utils.skip_init()
函数是基于一个两步模式实现的。
# 1. Initialize module on the meta device; all torch.nn.init ops have
# no-op behavior on the meta device.
m = nn.Linear(10, 5, device='meta')
# 2. Materialize an uninitialized (empty) form of the module on the CPU device.
# The result of this is a module instance with uninitialized parameters.
m.to_empty(device='cpu')
它的工作原理是将模块实例化到一个“meta”设备上,该设备具有张量形状信息,但不会分配任何存储。 torch.nn.init 操作为这个 meta 设备进行了特殊实现,使其成为无操作行为。这导致参数初始化逻辑基本上被跳过。
请注意,此模式仅适用于在构造期间正确支持 device 关键字参数的模块,如“更新模块以支持跳过初始化”中所述。