torch.nn.utils.parametrizations.spectral_norm#
- torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[来源]#
对给定模块中的参数应用谱归一化。
当应用于向量时,它简化为
谱归一化(Spectral normalization)通过减小模型的 Lipschitz 常数,稳定了生成对抗网络(GAN)中判别器(评论家)的训练。 是通过在每次访问权重时执行一次幂迭代(power method)来近似的。如果权重张量的维度大于 2,则在幂迭代方法中将其重塑为二维以获得谱范数。
请参阅 Spectral Normalization for Generative Adversarial Networks。
注意
此函数是使用
register_parametrization()中的参数化功能实现的。它是torch.nn.utils.spectral_norm()的重新实现。注意
当注册此约束时,与最大奇异值相关的奇异向量会被估计,而不是随机采样。每当在 training(训练)模式下访问张量时,这些向量都会通过执行
n_power_iterations次幂迭代进行更新。注意
如果在移除时 _SpectralNorm 模块(即 module.parametrization.weight[idx])处于训练模式,它将执行另一次幂迭代。如果您希望避免此次迭代,请在移除前将模块设置为评估(eval)模式。
- 参数:
- 返回:
已在指定权重上注册了新参数化的原始模块。
- 返回类型:
示例
>>> snm = spectral_norm(nn.Linear(20, 40)) >>> snm ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _SpectralNorm() ) ) ) >>> torch.linalg.matrix_norm(snm.weight, 2) tensor(1.0081, grad_fn=<AmaxBackward0>)