注意
转到末尾 下载完整示例代码。
MaskedTensor 概述#
本教程旨在作为使用 MaskedTensors 的起点,并讨论其掩码语义。
MaskedTensor 作为 torch.Tensor 的扩展,使用户能够
使用任何掩码语义(例如,可变长度张量、NaN* 运算符等)
区分 0 和 NaN 梯度
各种稀疏应用(请参阅下面的教程)
有关 MaskedTensors 的更详细介绍,请参阅 torch.masked 文档。
使用 MaskedTensor#
在本节中,我们将讨论如何使用 MaskedTensor,包括如何构造、访问数据和掩码,以及索引和切片。
准备工作#
我们将首先进行教程所需的必要设置
# Disable prototype warnings and such
构造#
有几种不同的方法可以构造 MaskedTensor
第一种方法是直接调用 MaskedTensor 类
第二种方法(也是我们推荐的方法)是使用
masked.masked_tensor()和masked.as_masked_tensor()工厂函数,它们分别类似于torch.tensor()和torch.as_tensor()
在本教程中,我们将假定导入语句为:from torch.masked import masked_tensor。
访问数据和掩码#
可以通过以下方式访问 MaskedTensor 中的底层字段
函数
MaskedTensor.get_data()函数
MaskedTensor.get_mask()。请记住,True表示“已指定”或“有效”,而False表示“未指定”或“无效”。
通常,返回的底层数据在未指定条目中可能无效,因此我们建议当用户需要一个没有任何掩码条目的张量时,应使用 MaskedTensor.to_tensor()(如上所示)来返回一个填充值的张量。
索引和切片#
MaskedTensor 是一个 Tensor 子类,这意味着它继承了与 torch.Tensor 相同的索引和切片语义。以下是一些常见索引和切片模式的示例
# float is used for cleaner visualization when being printed
为什么 MaskedTensor 有用?#
由于 MaskedTensor 将指定值和未指定值视为一等公民,而不是事后诸葛亮(通过填充值、NaN 等),因此它能够解决常规张量无法解决的几个不足之处;事实上,MaskedTensor 的诞生很大程度上就是由于这些反复出现的问题。
下面,我们将讨论 PyTorch 中目前仍未解决的一些最常见问题,并说明 MaskedTensor 如何解决这些问题。
区分 0 和 NaN 梯度#
torch.Tensor 遇到的一个问题是无法区分未定义的梯度 (NaN) 与实际为 0 的梯度。由于 PyTorch 没有一种方法可以将值标记为已指定/有效或未指定/无效,因此它被迫依赖 NaN 或 0(取决于用例),这会导致不稳定的语义,因为许多操作不适合正确处理 NaN 值。更令人困惑的是,有时根据操作顺序,梯度可能会有所不同(例如,取决于链式操作中 NaN 值出现的早晚)。
MaskedTensor 是完美的解决方案!
torch.where#
在 Issue 10729 中,我们注意到在使用 torch.where() 时,操作顺序可能会影响结果,因为我们难以区分 0 是真实值还是来自未定义梯度的值。因此,我们保持一致,掩码掉结果
当前结果
MaskedTensor 结果
这里的梯度仅提供给选定的子集。实际上,这改变了 where 的梯度,使其掩码掉元素而不是将它们设置为零。
另一个 torch.where#
Issue 52248 是另一个例子。
当前结果
MaskedTensor 结果
此问题与以下问题类似(甚至链接到下面的下一个问题),它表达了由于无法区分“无梯度”与“零梯度”而对意外行为感到沮丧,这反过来又使得处理其他操作变得难以理解。
使用掩码时,x/0 会产生 NaN grad#
在 Issue 4132 中,用户提出 x.grad 应该为 [0, 1] 而不是 [nan, 1],而 MaskedTensor 通过完全掩码掉梯度使这一点非常清楚。
当前结果
MaskedTensor 结果
torch.nansum() 和 torch.nanmean()#
在 Issue 67180 中,梯度未正确计算(一个长期存在的问题),而 MaskedTensor 则正确处理了这个问题。
当前结果
MaskedTensor 结果
安全 Softmax#
安全 Softmax 是 另一个经常出现的问题 的绝佳示例。简而言之,如果整个批次被“掩码掉”或完全由填充组成(在 Softmax 的情况下,这会转化为设置为 -inf),那么这将导致 NaN,从而可能导致训练发散。
幸运的是,MaskedTensor 解决了这个问题。考虑以下设置
例如,我们想沿 dim=0 计算 Softmax。请注意,第二列是“不安全的”(即完全掩码掉),因此在计算 Softmax 时,结果将为 0/0 = nan,因为 exp(-inf) = 0。然而,我们真正想要的是掩码掉梯度,因为它们未指定并且对训练无效。
PyTorch 结果
MaskedTensor 结果
实现缺失的 torch.nan* 运算符#
在 Issue 61474 中,有一个要求添加其他运算符来覆盖各种 torch.nan* 应用,例如 torch.nanmax、torch.nanmin 等。
通常,这些问题更自然地适用于掩码语义,因此我们建议使用 MaskedTensor,而不是引入额外的运算符。由于 nanmean 已合并,我们可以将其作为比较点
# z is just y with the zeros replaced with nan's
# MaskedTensor successfully ignores the 0's
在上面的示例中,我们构造了一个 y,并希望计算该序列的平均值,同时忽略零。可以使用 torch.nanmean 来执行此操作,但我们没有 torch.nan* 操作的其他实现。 MaskedTensor 通过使用基本操作来解决此问题,并且我们已经支持了问题中列出的其他操作。例如
实际上,忽略零时最小参数的索引是索引 1 中的 1。
MaskedTensor 还可以支持数据完全掩码掉时的归约,这相当于上面数据张量完全为 nan 的情况。 nanmean 将返回 nan(一个模糊的返回值),而 MaskedTensor 将更准确地指示已掩码掉的结果。
这与安全 Softmax 的问题类似,即 0/0 = nan,而我们实际上想要的是一个未定义的值。
结论#
在本教程中,我们介绍了 MaskedTensors 是什么,演示了如何使用它们,并通过一系列示例和它们已帮助解决的问题来阐述它们的价值。
进一步阅读#
要继续学习更多内容,您可以找到我们的 MaskedTensor Sparsity 教程,了解 MaskedTensor 如何实现稀疏性以及我们目前支持的不同存储格式。
# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%
脚本总运行时间:(0 分 0.002 秒)