torch.nn.utils.parametrize.register_parametrization#
- torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[源代码]#
将参数化注册到一个模块的张量上。
为简单起见,假设
tensor_name="weight"
。当访问module.weight
时,模块将返回参数化版本parametrization(module.weight)
。如果原始张量需要梯度,则反向传播将通过parametrization
进行微分,优化器将相应地更新张量。模块第一次注册参数化时,此函数将向模块添加一个类型为
ParametrizationList
的属性parametrizations
。张量
weight
上的参数化列表将可以在module.parametrizations.weight
下访问。原始张量将可以在
module.parametrizations.weight.original
下访问。可以通过在同一属性上注册多个参数化来连接参数化。
注册的参数化的训练模式会进行更新,以匹配宿主模块的训练模式。
参数化参数和缓冲区具有内置缓存系统,可以使用上下文管理器
cached()
激活。参数化可以有一个可选的实现方法,签名如下:
def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
当注册第一个参数化时,将调用此方法处理未参数化的张量,以计算原始张量的初始值。如果未实现此方法,则原始张量就是未参数化的张量。
如果一个张量上注册的所有参数化都实现了 right_inverse,则可以通过赋值来初始化参数化张量,如下面的示例所示。
第一个参数化可以依赖于多个输入。这可以通过从
right_inverse
返回一个张量元组来实现(参见下方RankOne
参数化的示例实现)。在这种情况下,无约束张量也位于
module.parametrizations.weight
下,名称分别为original0
、original1
,等等。注意
如果 unsafe=False(默认值),则 forward 和 right_inverse 方法都将被调用一次,以执行一系列一致性检查。如果 unsafe=True,则将在张量未参数化时调用 right_inverse,否则将不调用任何方法。
注意
在大多数情况下,
right_inverse
是一个函数,使得forward(right_inverse(X)) == X
(参见 右逆)。有时,当参数化不是满射时,放宽此限制可能是合理的。警告
如果一个参数化依赖于多个输入,
register_parametrization()
将会注册一些新的参数。如果此类参数化在优化器创建后注册,则需要手动将这些新参数添加到优化器中。请参阅torch.Optimizer.add_param_group()
。- 参数
- 关键字参数
unsafe (bool) – 一个布尔标志,表示参数化是否可能更改张量的 dtype 和形状。默认值:False 警告:注册时未检查参数化的一致性。请自行承担启用此标志的风险。
- 引发
ValueError – 如果模块没有名为
tensor_name
的参数或缓冲区- 返回类型
示例
>>> import torch >>> import torch.nn as nn >>> import torch.nn.utils.parametrize as P >>> >>> class Symmetric(nn.Module): >>> def forward(self, X): >>> return X.triu() + X.triu(1).T # Return a symmetric matrix >>> >>> def right_inverse(self, A): >>> return A.triu() >>> >>> m = nn.Linear(5, 5) >>> P.register_parametrization(m, "weight", Symmetric()) >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric True >>> A = torch.rand(5, 5) >>> A = A + A.T # A is now symmetric >>> m.weight = A # Initialize the weight to be the symmetric matrix A >>> print(torch.allclose(m.weight, A)) True
>>> class RankOne(nn.Module): >>> def forward(self, x, y): >>> # Form a rank 1 matrix multiplying two vectors >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) >>> >>> def right_inverse(self, Z): >>> # Project Z onto the rank 1 matrices >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) >>> # Return rescaled singular vectors >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt >>> >>> linear_rank_one = P.register_parametrization( ... nn.Linear(4, 4), "weight", RankOne() ... ) >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) 1