RMSNorm#
- class torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[source]#
对输入的小批量应用均方根层归一化。
此层实现了论文 Root Mean Square Layer Normalization 中描述的操作。
RMS 是在最后
D
个维度上计算的,其中D
是normalized_shape
的维度。例如,如果normalized_shape
是(3, 5)
(一个二维形状),则 RMS 是在输入的最后 2 个维度上计算的。- 参数
normalized_shape (int 或 list 或 torch.Size) –
input shape from an expected input of size
If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size.
eps (Optional[float]) – 用于数值稳定性的分母的加法值。默认值:
torch.finfo(x.dtype).eps
elementwise_affine (bool) – 一个布尔值,当设置为
True
时,此模块具有可学习的逐元素仿射参数,并初始化为 1(用于权重)。默认值:True
。
- 形状
输入:
输出:(与输入形状相同)
示例
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)