评价此页

torch.nanmedian#

torch.nanmedian(input) Tensor#

返回 input 中所有值的中位数,会忽略 NaN 值。

input 中没有 NaN 值时,此函数与 torch.median() 完全相同。当 input 包含一个或多个 NaN 值时,torch.median() 将始终返回 NaN,而此函数将返回 input 中非 NaN 元素的 median。如果 input 中的所有元素都是 NaN,则它也将返回 NaN

参数

input (Tensor) – 输入张量。

示例

>>> a = torch.tensor([1, float('nan'), 3, 2])
>>> a.median()
tensor(nan)
>>> a.nanmedian()
tensor(2.)
torch.nanmedian(input, dim=-1, keepdim=False, *, out=None)

返回一个命名元组 (values, indices),其中 values 包含 inputdim 维度上每行的中位数(忽略 NaN 值),而 indices 包含在 dim 维度上找到的中位数值的索引。

当缩减的行中没有 NaN 值时,此函数与 torch.median() 相同。当缩减的行包含一个或多个 NaN 值时,torch.median() 将始终将其缩减为 NaN,而此函数将将其缩减为非 NaN 元素的 असतात。如果缩减行中的所有元素均为 NaN,则也会将其缩减为 NaN

参数
  • input (Tensor) – 输入张量。

  • dim (int, optional) – 要缩减的维度。如果为 None,则缩减所有维度。

  • keepdim (bool, optional) – 输出张量是否保留 dim。默认为 False

关键字参数

out ((Tensor, Tensor), optional) – 第一个张量将填充中位数,第二个张量(必须是 long 类型)将填充其在 inputdim 维度中的索引。

示例

>>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]])
>>> a
tensor([[2., 3., 1.],
        [nan, 1., nan]])
>>> a.median(0)
torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1]))
>>> a.nanmedian(0)
torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0]))