注意
转到末尾 下载完整的示例代码。
参数化教程#
创建日期:2021年4月19日 | 最后更新:2024年2月5日 | 最后验证:2024年11月5日
作者: Mario Lezcano
对深度学习模型进行正则化是一项出人意料的挑战性任务。由于被优化函数的复杂性,像罚函数法这样的经典技术在应用于深度模型时往往力不从心。这在处理病态模型时尤其成问题。例如,在长序列上训练的 RNN 和 GAN 就是这类模型。近年来,已提出了许多技术来正则化这些模型并改善它们的收敛性。对于循环模型,有人提出控制 RNN 的循环核的奇异值以使其保持良好病态。例如,可以通过使循环核 正交 来实现这一点。另一种正则化循环模型的方法是通过“权重归一化”。这种方法建议将参数的学习与其范数鐒学习分离开来。为此,将参数除以其 Frobenius 范数,并学习一个单独的参数来编码其范数。GANs 中也提出了类似的正则化方法,称为“谱归一化”。该方法通过将参数除以它们的 谱范数(而不是 Frobenius 范数)来控制网络的 Lipschitz 常数。
所有这些方法都有一个共同的模式:它们都在使用参数之前以适当的方式转换参数。在第一种情况下,它们通过使用将矩阵映射到正交矩阵的函数使其正交。在权重归一化和谱归一化的情况下,它们将原始参数除以其范数。
更一般地说,所有这些示例都使用一个函数来对参数施加额外的结构。换句话说,它们使用一个函数来约束参数。
在本教程中,您将学习如何实现和使用这种模式来约束您的模型。这就像编写自己的 nn.Module 一样简单。
要求: torch>=1.9.0
手动实现参数化#
假设我们想要一个具有对称权重的方形线性层,即具有满足 X = Xᵀ 的权重 X。一种方法是将矩阵的上三角部分复制到其下三角部分。
tensor([[0.2994, 0.1108, 0.7758],
[0.1108, 0.8337, 0.3246],
[0.7758, 0.3246, 0.2519]])
然后,我们可以利用这个思想来实现一个具有对称权重的线性层。
class LinearSymmetric(nn.Module):
def __init__(self, n_features):
super().__init__()
self.weight = nn.Parameter(torch.rand(n_features, n_features))
def forward(self, x):
A = symmetric(self.weight)
return x @ A
该层随后可以像常规线性层一样使用。
layer = LinearSymmetric(3)
out = layer(torch.rand(8, 3))
尽管此实现正确且独立,但它存在一些问题:
它重新实现了该层。我们不得不将线性层实现为
x @ A。对于线性层来说,这问题不大,但想象一下不得不重写 CNN 或 Transformer……它没有分离层和参数化。如果参数化更复杂,我们就必须为想要使用它的每个层重写其代码。
它每次使用该层时都会重新计算参数化。如果我们(想象 RNN 的循环核)在前向传播过程中多次使用该层,它每次调用该层时都会计算相同的
A。
参数化简介#
参数化可以解决所有这些问题以及其他问题。
让我们开始使用 torch.nn.utils.parametrize 重写上面的代码。我们唯一需要做的就是将参数化编写成一个常规的 nn.Module。
这就足够了。有了这个,我们就可以通过执行以下操作将任何常规层转换为对称层:
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())
ParametrizedLinear(
in_features=3, out_features=3, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): Symmetric()
)
)
)
现在,线性层的矩阵是对称的。
A = layer.weight
assert torch.allclose(A, A.T) # A is symmetric
print(A) # Quick visual check
tensor([[-0.0900, 0.2098, 0.2589],
[ 0.2098, -0.3692, -0.5678],
[ 0.2589, -0.5678, 0.0852]], grad_fn=<AddBackward0>)
我们可以对任何其他层做同样的事情。例如,我们可以创建一个具有 斜对称 核的 CNN。我们使用类似的参数化,将上三角部分乘以符号反转后复制到下三角部分。
tensor([[ 0.0000, 0.0053, -0.1311],
[-0.0053, 0.0000, 0.1016],
[ 0.1311, -0.1016, 0.0000]], grad_fn=<SelectBackward0>)
tensor([[ 0.0000, 0.0430, -0.1106],
[-0.0430, 0.0000, -0.0752],
[ 0.1106, 0.0752, 0.0000]], grad_fn=<SelectBackward0>)
检查参数化模块#
当模块被参数化时,我们会发现该模块在三个方面发生了变化:
model.weight现在是一个属性。它有一个新的
module.parametrizations属性。未经参数化的权重已移至
module.parametrizations.weight.original。
在参数化 weight 之后,layer.weight 被转换为一个 Python 属性。每次我们请求 layer.weight 时,这个属性都会计算 parametrization(weight),正如我们在上面实现 LinearSymmetric 时所做的那样。
已注册的参数化存储在模块内的 parametrizations 属性下。
Unparametrized:
Linear(in_features=3, out_features=3, bias=True)
Parametrized:
ParametrizedLinear(
in_features=3, out_features=3, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): Symmetric()
)
)
)
这个 parametrizations 属性是一个 nn.ModuleDict,可以像这样访问:
print(layer.parametrizations)
print(layer.parametrizations.weight)
ModuleDict(
(weight): ParametrizationList(
(0): Symmetric()
)
)
ParametrizationList(
(0): Symmetric()
)
这个 nn.ModuleDict 的每个元素都是一个 ParametrizationList,它的行为类似于 nn.Sequential。这个列表允许我们在一个张量上连接参数化。由于这是一个列表,我们可以通过索引来访问参数化。我们的 Symmetric 参数化就在这里:
print(layer.parametrizations.weight[0])
Symmetric()
我们注意到的另一件事是,如果我们打印参数,我们会看到 weight 参数已被移动。
print(dict(layer.named_parameters()))
{'bias': Parameter containing:
tensor([-0.4116, 0.4618, 0.3337], requires_grad=True), 'parametrizations.weight.original': Parameter containing:
tensor([[ 0.0344, 0.0072, 0.0562],
[ 0.2481, 0.2788, 0.4807],
[ 0.3946, -0.2378, 0.0221]], requires_grad=True)}
它现在位于 layer.parametrizations.weight.original 下。
Parameter containing:
tensor([[ 0.0344, 0.0072, 0.0562],
[ 0.2481, 0.2788, 0.4807],
[ 0.3946, -0.2378, 0.0221]], requires_grad=True)
除了这三个微小的差异外,参数化与我们的手动实现完全相同。
symmetric = Symmetric()
weight_orig = layer.parametrizations.weight.original
print(torch.dist(layer.weight, symmetric(weight_orig)))
tensor(0., grad_fn=<DistBackward0>)
参数化是一等公民#
由于 layer.parametrizations 是一个 nn.ModuleList,这意味着参数化已正确注册为原始模块的子模块。因此,将参数注册到模块的规则同样适用于注册参数化。例如,如果参数化有参数,当调用 model = model.cuda() 时,这些参数将从 CPU 移动到 CUDA。
缓存参数化值#
参数化通过上下文管理器 parametrize.cached() 提供内置的缓存系统。
class NoisyParametrization(nn.Module):
def forward(self, X):
print("Computing the Parametrization")
return X
layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NoisyParametrization())
print("Here, layer.weight is recomputed every time we call it")
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
with parametrize.cached():
print("Here, it is computed just the first time layer.weight is called")
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
Computing the Parametrization
Here, layer.weight is recomputed every time we call it
Computing the Parametrization
Computing the Parametrization
Computing the Parametrization
Here, it is computed just the first time layer.weight is called
Computing the Parametrization
连接参数化#
连接两个参数化就像在同一个张量上注册它们一样简单。我们可以使用它来从更简单的参数化创建更复杂的参数化。例如,Cayley 映射 将斜对称矩阵映射到具有正行列式的正交矩阵。我们可以将 Skew 和实现 Cayley 映射的参数化连接起来,以获得具有正交权重的层。
class CayleyMap(nn.Module):
def __init__(self, n):
super().__init__()
self.register_buffer("Id", torch.eye(n))
def forward(self, X):
# (I + X)(I - X)^{-1}
return torch.linalg.solve(self.Id - X, self.Id + X)
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
parametrize.register_parametrization(layer, "weight", CayleyMap(3))
X = layer.weight
print(torch.dist(X.T @ X, torch.eye(3))) # X is orthogonal
tensor(1.9881e-07, grad_fn=<DistBackward0>)
这也可以用于修剪参数化模块或重用参数化。例如,矩阵指数将对称矩阵映射到对称正定 (SPD) 矩阵。但矩阵指数也将斜对称矩阵映射到正交矩阵。利用这两个事实,我们可以将之前的参数化重用到我们的优势。
class MatrixExponential(nn.Module):
def forward(self, X):
return torch.matrix_exp(X)
layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", MatrixExponential())
X = layer_orthogonal.weight
print(torch.dist(X.T @ X, torch.eye(3))) # X is orthogonal
layer_spd = nn.Linear(3, 3)
parametrize.register_parametrization(layer_spd, "weight", Symmetric())
parametrize.register_parametrization(layer_spd, "weight", MatrixExponential())
X = layer_spd.weight
print(torch.dist(X, X.T)) # X is symmetric
print((torch.linalg.eigvalsh(X) > 0.).all()) # X is positive definite
tensor(2.5723e-07, grad_fn=<DistBackward0>)
tensor(5.2684e-09, grad_fn=<DistBackward0>)
tensor(True)
初始化参数化#
参数化提供了一种初始化它们的方法。如果我们实现一个具有以下签名的 right_inverse 方法:
def right_inverse(self, X: Tensor) -> Tensor
在分配给参数化张量时将使用它。
让我们升级我们对 Skew 类的实现以支持这一点。
我们现在可以初始化一个使用 Skew 参数化的层。
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
X = torch.rand(3, 3)
X = X - X.T # X is now skew-symmetric
layer.weight = X # Initialize layer.weight to be X
print(torch.dist(layer.weight, X)) # layer.weight == X
tensor(0., grad_fn=<DistBackward0>)
这个 right_inverse 在连接参数化时按预期工作。为了说明这一点,让我们升级 Cayley 参数化以支持初始化:
class CayleyMap(nn.Module):
def __init__(self, n):
super().__init__()
self.register_buffer("Id", torch.eye(n))
def forward(self, X):
# Assume X skew-symmetric
# (I + X)(I - X)^{-1}
return torch.linalg.solve(self.Id - X, self.Id + X)
def right_inverse(self, A):
# Assume A orthogonal
# See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
# (A - I)(A + I)^{-1}
return torch.linalg.solve(A + self.Id, self.Id - A)
layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", CayleyMap(3))
# Sample an orthogonal matrix with positive determinant
X = torch.empty(3, 3)
nn.init.orthogonal_(X)
if X.det() < 0.:
X[0].neg_()
layer_orthogonal.weight = X
print(torch.dist(layer_orthogonal.weight, X)) # layer_orthogonal.weight == X
tensor(0.1324, grad_fn=<DistBackward0>)
这个初始化步骤可以更简洁地写成:
layer_orthogonal.weight = nn.init.orthogonal_(layer_orthogonal.weight)
此方法的名称来源于我们通常期望 forward(right_inverse(X)) == X。这是重写使用值 X 初始化后的前向传播应返回值 X 的直接方法。实际上,这种约束并未得到严格执行。事实上,有时放松这种关系可能会引起人们的兴趣。例如,考虑以下随机剪枝方法的实现:
class PruningParametrization(nn.Module):
def __init__(self, X, p_drop=0.2):
super().__init__()
# sample zeros with probability p_drop
mask = torch.full_like(X, 1.0 - p_drop)
self.mask = torch.bernoulli(mask)
def forward(self, X):
return X * self.mask
def right_inverse(self, A):
return A
在这种情况下,对于每个矩阵 A,forward(right_inverse(A)) == A 并不成立。只有当矩阵 A 中的零位置与掩码中的零位置相同时才成立。即使那样,如果我们为一个剪枝参数分配一个张量,那么该张量实际上已经被剪枝也就不足为奇了。
layer = nn.Linear(3, 4)
X = torch.rand_like(layer.weight)
print(f"Initialization matrix:\n{X}")
parametrize.register_parametrization(layer, "weight", PruningParametrization(layer.weight))
layer.weight = X
print(f"\nInitialized weight:\n{layer.weight}")
Initialization matrix:
tensor([[0.1694, 0.1887, 0.3677],
[0.4180, 0.1883, 0.1400],
[0.9703, 0.4129, 0.7185],
[0.5919, 0.7431, 0.2885]])
Initialized weight:
tensor([[0.1694, 0.1887, 0.3677],
[0.0000, 0.1883, 0.1400],
[0.9703, 0.0000, 0.7185],
[0.5919, 0.7431, 0.2885]], grad_fn=<MulBackward0>)
移除参数化#
我们可以使用 parametrize.remove_parametrizations() 来移除模块中参数或缓冲区的所有参数化。
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight")
print("\nAfter. Weight has skew-symmetric values but it is unconstrained:")
print(layer)
print(layer.weight)
Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.5033, -0.5729, 0.0168],
[ 0.0554, 0.1867, 0.5595],
[-0.4279, 0.0124, -0.4335]], requires_grad=True)
Parametrized:
ParametrizedLinear(
in_features=3, out_features=3, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): Skew()
)
)
)
tensor([[ 0.0000, -0.5729, 0.0168],
[ 0.5729, 0.0000, 0.5595],
[-0.0168, -0.5595, 0.0000]], grad_fn=<SubBackward0>)
After. Weight has skew-symmetric values but it is unconstrained:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.0000, -0.5729, 0.0168],
[ 0.5729, 0.0000, 0.5595],
[-0.0168, -0.5595, 0.0000]], requires_grad=True)
移除参数化时,我们可以选择保留原始参数(即 layer.parametriations.weight.original 中的参数),而不是其参数化版本,方法是将标志 leave_parametrized=False 设置为 True。
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
print("\nAfter. Same as Before:")
print(layer)
print(layer.weight)
Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.5575, 0.0889, -0.4280],
[-0.1325, -0.5174, 0.2741],
[-0.2705, 0.4845, -0.5173]], requires_grad=True)
Parametrized:
ParametrizedLinear(
in_features=3, out_features=3, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): Skew()
)
)
)
tensor([[ 0.0000, 0.0889, -0.4280],
[-0.0889, 0.0000, 0.2741],
[ 0.4280, -0.2741, 0.0000]], grad_fn=<SubBackward0>)
After. Same as Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.0000, 0.0889, -0.4280],
[ 0.0000, 0.0000, 0.2741],
[ 0.0000, 0.0000, 0.0000]], requires_grad=True)
脚本总运行时间: (0 分钟 0.063 秒)