torch.nn.utils.parametrizations.spectral_norm

torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)

Apply spectral normalization to a parameter in the given module.

$$\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \quad \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$

When applied to a vector, it simplifies to

$$\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}$$
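As a quick check of the vector case, the sketch below applies spectral_norm to the 1-D bias of an nn.Linear and compares the result with a manual division by the L2 norm. Targeting name="bias" is an assumption made purely for illustration; it is not a typical use of this function.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)

# Illustrative only: parametrize the 1-D bias instead of the usual 2-D weight.
layer = spectral_norm(nn.Linear(20, 40), name="bias")

# The unconstrained tensor is kept under parametrizations.<name>.original.
raw_bias = layer.parametrizations.bias.original
manual = raw_bias / torch.linalg.vector_norm(raw_bias)

# The parametrized bias should have unit L2 norm and match the manual result.
bias_norm = torch.linalg.vector_norm(layer.bias).item()
matches_manual = torch.allclose(layer.bias, manual, atol=1e-6)
print(bias_norm, matches_manual)
```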

Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant of the model. $\sigma$ is approximated by performing n_power_iterations iterations of the power method (one by default) every time the weight is accessed. If the weight tensor has more than two dimensions, it is reshaped to 2D for the power iteration that estimates the spectral norm.
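The power method mentioned above can be sketched in a few lines. This is a minimal illustration, not the library's internal code: the helper name estimate_sigma and the iteration count are assumptions, and the result is compared against torch.linalg.matrix_norm on the same flattened matrix.

```python
import torch
import torch.nn.functional as F


def estimate_sigma(weight, n_iters=200, eps=1e-12):
    """Hypothetical helper: estimate the largest singular value by power iteration."""
    # Flatten all trailing dimensions so that, e.g., a conv kernel becomes 2-D.
    w = weight.reshape(weight.shape[0], -1)
    # Random starting vector for the right singular vector.
    v = F.normalize(torch.randn(w.shape[1]), dim=0, eps=eps)
    u = F.normalize(w @ v, dim=0, eps=eps)
    for _ in range(n_iters):
        # Alternate between the left (u) and right (v) singular vector estimates.
        u = F.normalize(w @ v, dim=0, eps=eps)
        v = F.normalize(w.T @ u, dim=0, eps=eps)
    # With unit-norm u and v, u^T W v approximates sigma(W).
    return torch.dot(u, w @ v)


torch.manual_seed(0)
weight = torch.randn(8, 3, 3, 3)  # shaped like a Conv2d kernel
sigma = estimate_sigma(weight)
exact = torch.linalg.matrix_norm(weight.reshape(8, -1), ord=2)
print(sigma.item(), exact.item())
```

The estimate approaches the exact value from below, so dividing by it yields a weight whose spectral norm is slightly above 1 until the iteration has converged.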

See Spectral Normalization for Generative Adversarial Networks.

Note

This function is implemented using the parametrization functionality in register_parametrization(). It is a reimplementation of torch.nn.utils.spectral_norm().

Note

When this constraint is registered, the singular vectors associated with the largest singular value are estimated rather than sampled at random. They are then updated by performing n_power_iterations iterations of the power method whenever the tensor is accessed while the module is in training mode.
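The difference between training and evaluation mode can be observed by reading the weight twice in each mode. The following is a small sketch with illustrative variable names: in training mode each access runs the power method and so typically returns a slightly different tensor, while in eval mode the stored singular-vector estimates are reused unchanged.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)
layer = spectral_norm(nn.Linear(20, 40))

with torch.no_grad():
    layer.train()
    w_train_1 = layer.weight.clone()  # this access runs a power iteration
    w_train_2 = layer.weight.clone()  # and so does this one

    layer.eval()
    w_eval_1 = layer.weight.clone()   # no power iteration in eval mode
    w_eval_2 = layer.weight.clone()

# Typically True: the sigma estimate moved between the two training accesses.
train_changed = not torch.equal(w_train_1, w_train_2)
# Eval-mode accesses reuse the same estimates, so the weights are identical.
eval_stable = torch.equal(w_eval_1, w_eval_2)
print(train_changed, eval_stable)
```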

Note

If the _SpectralNorm module, i.e., module.parametrizations.weight[idx], is in training mode on removal, it will perform another power iteration. To avoid this iteration, set the module to eval mode before removing the parametrization.
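A possible removal workflow, sketched under the assumption that the parametrization is removed with torch.nn.utils.parametrize.remove_parametrizations and its default leave_parametrized=True: switching to eval mode first means the weight left behind is the same one that was last observed.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)
layer = spectral_norm(nn.Linear(20, 40))

# Eval mode prevents an extra power iteration while the parametrization is removed.
layer.eval()
w_before = layer.weight.detach().clone()

# By default the module keeps the normalized weight as a plain parameter.
parametrize.remove_parametrizations(layer, "weight")

still_parametrized = parametrize.is_parametrized(layer)
same_weight = torch.allclose(layer.weight, w_before)
print(still_parametrized, same_weight)
```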

Parameters
  • module (nn.Module) – containing module

  • name (str, optional) – name of weight parameter. Default: "weight".

  • n_power_iterations (int, optional) – number of power iterations to calculate spectral norm. Default: 1.

  • eps (float, optional) – epsilon for numerical stability in calculating norms. Default: 1e-12.

  • dim (int, optional) – dimension corresponding to number of outputs. Default: 0, except for modules that are instances of ConvTranspose{1,2,3}d, where it is 1.
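To illustrate the dim default, the sketch below normalizes a Conv2d and a ConvTranspose2d and then measures the spectral norm of each weight after moving the output dimension to the front and flattening the rest. The layer sizes and n_power_iterations=20 are arbitrary choices for the example, and the flattening shown is an assumption about how the 2D matrix is formed, not a quote of the internal code.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)
conv = spectral_norm(nn.Conv2d(3, 8, kernel_size=3), n_power_iterations=20)
deconv = spectral_norm(nn.ConvTranspose2d(8, 3, kernel_size=3), n_power_iterations=20)

with torch.no_grad():
    # One training-mode access each, so the power method runs with 20 iterations.
    _ = conv.weight
    _ = deconv.weight
    conv.eval()
    deconv.eval()

    # Conv2d weight is (out, in, kH, kW): outputs are on dim 0.
    conv_mat = conv.weight.reshape(8, -1)
    # ConvTranspose2d weight is (in, out, kH, kW): outputs are on dim 1.
    deconv_mat = deconv.weight.transpose(0, 1).reshape(3, -1)

    conv_sigma = torch.linalg.matrix_norm(conv_mat, ord=2).item()
    deconv_sigma = torch.linalg.matrix_norm(deconv_mat, ord=2).item()

# Both values should be close to 1, up to the power-method approximation error.
print(conv_sigma, deconv_sigma)
```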

Returns

The original module with a new parametrization registered to the specified weight

Return type

Module

Example

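>>> import torch
>>> import torch.nn as nn
>>> from torch.nn.utils.parametrizations import spectral_norm
>>> snm = spectral_norm(nn.Linear(20, 40))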
>>> snm
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
    )
  )
)
>>> torch.linalg.matrix_norm(snm.weight, 2)
tensor(1.0081, grad_fn=<AmaxBackward0>)
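The norm in the example is only approximately 1 because σ is an estimate. As a hedged illustration (the count of 500 accesses is arbitrary), reading the weight repeatedly in training mode runs further power iterations, which should bring the spectral norm of the parametrized weight closer to 1.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)
snm = spectral_norm(nn.Linear(20, 40))

with torch.no_grad():
    # Spectral norm right after registration (few power iterations so far).
    before = torch.linalg.matrix_norm(snm.weight, 2).item()

    # Each training-mode access refines the singular-vector estimates.
    for _ in range(500):
        _ = snm.weight

    # Freeze the estimates, then measure again.
    snm.eval()
    after = torch.linalg.matrix_norm(snm.weight, 2).item()

print(before, after)
```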