torch.nn.utils.parametrizations.spectral_norm#
- torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[源代码]#
对给定模块中的参数应用谱归一化。
当应用于向量时,它简化为
谱归一化通过降低模型的Lipschitz常数来稳定生成对抗网络(GAN)中判别器(critic)的训练。是利用幂法进行一次迭代来近似的,每次访问权重时都会进行。如果权重张量的维度大于2,则在幂法中会将其重塑为2D以获得谱范数。
请参阅 Spectral Normalization for Generative Adversarial Networks。
注意
此函数使用
register_parametrization()
中的参数化功能来实现。它是torch.nn.utils.spectral_norm()
的重新实现。注意
当注册此约束时,将估计与最大奇异值相关的奇异向量,而不是随机采样。然后,当模块在训练模式下访问张量时,会通过进行
n_power_iterations
次 幂法 来更新它们。注意
如果 _SpectralNorm 模块,即 module.parametrization.weight[idx],在移除时处于训练模式,它将执行另一次幂迭代。如果您想避免此迭代,请在移除模块之前将其设置为评估模式。
- 参数
- 返回
注册了新参数化的原始模块,该参数化已应用于指定的权重
- 返回类型
示例
>>> snm = spectral_norm(nn.Linear(20, 40)) >>> snm ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _SpectralNorm() ) ) ) >>> torch.linalg.matrix_norm(snm.weight, 2) tensor(1.0081, grad_fn=<AmaxBackward0>)