注意
跳转至页尾 下载完整示例代码。
MaskedTensor 概述#
本教程旨在作为使用 MaskedTensor 的入门指南,并探讨其掩码(masking)语义。
MaskedTensor 是 torch.Tensor 的扩展,它使用户能够
使用任意掩码语义(例如:变长张量、nan* 运算符等)
区分 0 梯度和 NaN 梯度
实现各种稀疏应用(见下文教程)
如需深入了解什么是 MaskedTensor,请参阅 torch.masked 文档。
使用 MaskedTensor#
在本节中,我们将讨论如何使用 MaskedTensor,包括如何构建、访问数据和掩码,以及索引和切片操作。
准备工作#
首先进行教程所需的设置
# Disable prototype warnings and such
构造#
构建 MaskedTensor 有几种不同的方法
第一种方法是直接调用 MaskedTensor 类
第二种(也是我们推荐的)方法是使用
masked.masked_tensor()和masked.as_masked_tensor()工厂函数,它们类似于torch.tensor()和torch.as_tensor()
在本教程中,我们假设已包含导入语句:from torch.masked import masked_tensor。
访问数据和掩码#
可以通过以下方式访问 MaskedTensor 中的底层字段
MaskedTensor.get_data()函数MaskedTensor.get_mask()函数。请记住,True表示“已指定”或“有效”,而False表示“未指定”或“无效”。
通常,返回的底层数据在未指定条目中可能无效,因此我们建议当用户需要一个不包含任何掩码条目的 Tensor 时,使用 MaskedTensor.to_tensor()(如上所示)来返回一个已填充值的 Tensor。
索引和切片#
MaskedTensor 是 Tensor 的子类,这意味着它继承了与 torch.Tensor 相同的索引和切片语义。以下是一些常见索引和切片模式的示例
# float is used for cleaner visualization when being printed
为什么 MaskedTensor 很有用?#
由于 MaskedTensor 将已指定值和未指定值视为一等公民,而不是事后补救(通过填充值、NaN 等),因此它能够解决普通 Tensor 无法解决的几个缺点;实际上,MaskedTensor 的诞生在很大程度上就是为了解决这些反复出现的问题。
下面我们将讨论一些 PyTorch 今天尚未解决的最常见问题,并说明 MaskedTensor 如何解决这些问题。
区分 0 梯度和 NaN 梯度#
torch.Tensor 遇到的一个问题是无法区分未定义的梯度(NaN)和实际上为 0 的梯度。由于 PyTorch 没有标记值是“已指定/有效”还是“未指定/无效”的方法,它被迫依赖 NaN 或 0(取决于具体用例),导致语义不可靠,因为许多操作无法正确处理 NaN 值。更令人困惑的是,有时根据操作顺序的不同,梯度可能会发生变化(例如,取决于 NaN 值在操作链中出现的早晚)。
MaskedTensor 是解决此问题的完美方案!
torch.where#
在 Issue 10729 中,我们注意到使用 torch.where() 时操作顺序可能会产生影响,因为我们难以区分 0 是真实的 0 还是来自未定义梯度。因此,我们保持一致并对结果进行掩码处理。
当前结果
MaskedTensor 结果
这里的梯度仅提供给选定的子集。实际上,这会将 where 的梯度更改为对元素进行掩码处理,而不是将它们设置为零。
另一个 torch.where#
Issue 52248 是另一个例子。
当前结果
MaskedTensor 结果
这个问题类似(甚至链接到下面的下一个问题),它表达了对意外行为的沮丧,原因在于无法区分“无梯度”和“零梯度”,这反过来使得处理其他操作变得难以推理。
使用掩码时,x/0 产生 NaN 梯度#
在 Issue 4132 中,用户建议 x.grad 应该是 [0, 1] 而不是 [nan, 1],而 MaskedTensor 通过完全屏蔽梯度使这一点变得非常明确。
当前结果
MaskedTensor 结果
torch.nansum() 和 torch.nanmean()#
在 Issue 67180 中,梯度没有被正确计算(一个长期存在的问题),而 MaskedTensor 处理得当。
当前结果
MaskedTensor 结果
安全 Softmax#
安全 Softmax 是另一个经常出现的 问题。简而言之,如果整个批次都被“屏蔽”或完全由填充组成(在 Softmax 的情况下,这相当于被设置为 -inf),则会导致 NaN,从而可能导致训练发散。
幸运的是,MaskedTensor 解决了这个问题。考虑此设置
例如,我们想要计算 dim=0 上的 Softmax。请注意,第二列是“不安全”的(即完全被屏蔽),因此当计算 Softmax 时,结果将产生 0/0 = nan,因为 exp(-inf) = 0。然而,我们真正想要的是梯度被屏蔽掉,因为它们是未指定的,对于训练来说是无效的。
PyTorch 结果
MaskedTensor 结果
实现缺失的 torch.nan* 运算符#
在 Issue 61474 中,有人请求添加额外的运算符来覆盖各种 torch.nan* 应用,例如 torch.nanmax、torch.nanmin 等。
通常,这些问题更自然地属于掩码语义,因此我们建议使用 MaskedTensor,而不是引入额外的运算符。由于 nanmean 已经落地,我们可以将其作为比较点。
# z is just y with the zeros replaced with nan's
# MaskedTensor successfully ignores the 0's
在上面的示例中,我们构建了一个 y,并希望在忽略零的情况下计算序列的平均值。torch.nanmean 可以用来实现这一点,但我们没有实现其余的 torch.nan* 操作。MaskedTensor 通过能够使用基础操作解决了这个问题,并且我们已经支持了 issue 中列出的其他操作。例如
确实,忽略 0 时,最小参数的索引是索引 1 处的 1。
MaskedTensor 还可以在数据完全被屏蔽时支持缩减操作,这等同于上述数据 Tensor 完全是 nan 的情况。nanmean 会返回 nan(一个有歧义的返回值),而 MaskedTensor 则会更准确地指示出一个被屏蔽的结果。
这是一个类似于安全 Softmax 的问题,即当我们需要一个未定义值时出现了 0/0 = nan。
结论#
在本教程中,我们介绍了什么是 MaskedTensor,展示了如何使用它们,并通过一系列它们帮助解决的示例和问题激发了它们的价值。
进一步阅读#
要继续深入了解,您可以查看我们的 MaskedTensor 稀疏性教程,了解 MaskedTensor 如何实现稀疏性以及我们目前支持的不同存储格式。
# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%
脚本总运行时间:(0 分 0.002 秒)