评价此页

torch.nn.utils.parametrizations.spectral_norm#

torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[来源]#

对给定模块中的参数应用谱归一化。

WSN=Wσ(W),σ(W)=maxh:h0Wh2h2\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

当应用于向量时,它简化为

xSN=xx2\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

谱归一化(Spectral normalization)通过减小模型的 Lipschitz 常数,稳定了生成对抗网络(GAN)中判别器(评论家)的训练。σ\sigma 是通过在每次访问权重时执行一次幂迭代(power method)来近似的。如果权重张量的维度大于 2,则在幂迭代方法中将其重塑为二维以获得谱范数。

请参阅 Spectral Normalization for Generative Adversarial Networks

注意

此函数是使用 register_parametrization() 中的参数化功能实现的。它是 torch.nn.utils.spectral_norm() 的重新实现。

注意

当注册此约束时,与最大奇异值相关的奇异向量会被估计,而不是随机采样。每当在 training(训练)模式下访问张量时,这些向量都会通过执行 n_power_iterations幂迭代进行更新。

注意

如果在移除时 _SpectralNorm 模块(即 module.parametrization.weight[idx])处于训练模式,它将执行另一次幂迭代。如果您希望避免此次迭代,请在移除前将模块设置为评估(eval)模式。

参数:
  • module (nn.Module) – 包含的模块

  • name (str, optional) – 权重参数的名称。默认值:"weight"

  • n_power_iterations (int, optional) – 计算谱范数时进行的幂迭代次数。默认值:1

  • eps (float, optional) – 用于计算范数时数值稳定性的 epsilon 值。默认值:1e-12

  • dim (int, optional) – 对应于输出数量的维度。默认值:0;但对于作为 ConvTranspose{1,2,3}d 实例的模块,默认值为 1

返回:

已在指定权重上注册了新参数化的原始模块。

返回类型:

模块

示例

>>> snm = spectral_norm(nn.Linear(20, 40))
>>> snm
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
    )
  )
)
>>> torch.linalg.matrix_norm(snm.weight, 2)
tensor(1.0081, grad_fn=<AmaxBackward0>)