torch.nn.utils.get_total_norm#
- torch.nn.utils.get_total_norm(tensors, norm_type=2.0, error_if_nonfinite=False, foreach=None)[源代码]#
计算一个张量可迭代对象的范数。
范数是通过对单个张量的范数进行计算得到的,就好像将单个张量的范数拼接成一个单独的向量一样。
- 参数
tensors (Iterable[Tensor] or Tensor) – 一个张量可迭代对象或单个张量,将对其进行归一化。
norm_type (float) – 使用的 p-范数的类型。可以是
'inf'
表示无穷范数。error_if_nonfinite (bool) – 如果为 True,则在
tensors
的总范数是nan
、inf
或-inf
时抛出错误。默认为False
。foreach (bool) – 使用更快的基于 foreach 的实现。如果为
None
,则对 CUDA 和 CPU 原生张量使用 foreach 实现,而对其他设备类型则静默回退到慢速实现。默认为None
。
- 返回
张量的总范数(视为一个单独的向量)。
- 返回类型