RMSNorm
class torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)
Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
This layer implements the operation described in the paper Root Mean Square Layer Normalization.
The RMS is taken over the last D dimensions, where D is the number of dimensions of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape), the RMS is computed over the last 2 dimensions of the input.
- Parameters
- Shape
Input:
Output: (same shape as input)
Examples
>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)