评价此页

MaskedTensor 概述#

本教程旨在作为使用 MaskedTensor 的起点,并讨论其掩码语义。

MaskedTensor 是 torch.Tensor 的扩展,它为用户提供了以下能力:

  • 使用任何掩码语义(例如,可变长度张量、nan* 运算符等)

  • 区分 0 和 NaN 梯度

  • 各种稀疏应用(参见下面的教程)

有关 MaskedTensor 的更详细介绍,请参阅 torch.masked 文档

使用 MaskedTensor#

在本节中,我们将讨论如何使用 MaskedTensor,包括如何构造、访问数据和掩码,以及索引和切片。

准备工作#

我们首先进行教程的必要设置

# Disable prototype warnings and such

构造#

有几种不同的方法可以构造 MaskedTensor

  • 第一种方法是直接调用 MaskedTensor 类

  • 第二种(也是我们推荐的方式)是使用 masked.masked_tensor()masked.as_masked_tensor() 工厂函数,它们类似于 torch.tensor()torch.as_tensor()

在本教程中,我们假设导入行:from torch.masked import masked_tensor

访问数据和掩码#

MaskedTensor 中的底层字段可以通过以下方式访问

  • MaskedTensor.get_data() 函数

  • MaskedTensor.get_mask() 函数。请注意,True 表示“指定”或“有效”,而 False 表示“未指定”或“无效”。

通常,返回的底层数据可能在未指定的条目中无效,因此我们建议当用户需要一个没有 masked 条目的 Tensor 时,他们使用 MaskedTensor.to_tensor()(如上所示)返回一个填充了值的 Tensor。

索引和切片#

MaskedTensor 是 Tensor 的子类,这意味着它继承了与 torch.Tensor 相同的索引和切片语义。以下是一些常见索引和切片模式的示例

# float is used for cleaner visualization when being printed

为什么 MaskedTensor 有用?#

由于 MaskedTensor 将指定和未指定值作为一流公民处理,而不是事后添加(通过填充值、NaN 等),它能够解决常规 Tensor 无法解决的几个缺点;事实上,MaskedTensor 的诞生很大程度上是由于这些反复出现的问题。

下面,我们将讨论目前 PyTorch 中一些最常见且尚未解决的问题,并说明 MaskedTensor 如何解决这些问题。

区分 0 和 NaN 梯度#

torch.Tensor 遇到的一个问题是无法区分未定义(NaN)的梯度和实际为 0 的梯度。由于 PyTorch 没有一种方法来将值标记为指定/有效或未指定/无效,它被迫依赖于 NaN 或 0(取决于用例),导致语义不可靠,因为许多操作并非旨在正确处理 NaN 值。更令人困惑的是,有时根据操作的顺序,梯度可能会有所不同(例如,取决于 NaN 值在操作链中出现的时间)。

MaskedTensor 是解决此问题的完美方案!

torch.where#

Issue 10729 中,我们注意到在使用 torch.where() 时操作顺序可能很重要,因为我们难以区分 0 是真实的 0 还是来自未定义梯度。因此,我们保持一致并屏蔽结果

当前结果

MaskedTensor 结果

这里的梯度只提供给选定的子集。实际上,这将 where 的梯度更改为屏蔽元素而不是将其设置为零。

另一个 torch.where#

Issue 52248 是另一个例子。

当前结果

MaskedTensor 结果

此问题类似(甚至链接到下面的下一个问题),它表达了对意外行为的沮丧,因为无法区分“无梯度”与“零梯度”,这反过来使得使用其他操作变得难以理解。

使用掩码时,x/0 产生 NaN 梯度#

Issue 4132 中,用户建议 x.grad 应该是 [0, 1] 而不是 [nan, 1],而 MaskedTensor 通过完全屏蔽梯度使其非常清晰。

当前结果

MaskedTensor 结果

torch.nansum()torch.nanmean()#

Issue 67180 中,梯度计算不正确(一个长期存在的问题),而 MaskedTensor 正确处理了它。

当前结果

MaskedTensor 结果

安全 Softmax#

安全 Softmax 是 一个频繁出现的问题 的另一个很好的例子。简而言之,如果有一个完整的批次被“屏蔽掉”或完全由填充组成(在 softmax 的情况下,这意味着设置为 -inf),那么这将导致 NaNs,从而可能导致训练发散。

幸运的是,MaskedTensor 解决了这个问题。考虑以下设置

例如,我们想沿着 dim=0 计算 softmax。请注意,第二列是“不安全”的(即完全被屏蔽),因此当计算 softmax 时,结果将产生 0/0 = nan,因为 exp(-inf) = 0。然而,我们真正希望的是梯度被屏蔽,因为它们是未指定的,并且对训练无效。

PyTorch 结果

MaskedTensor 结果

实现缺失的 torch.nan* 运算符#

Issue 61474 中,有一个请求是添加额外的运算符以涵盖各种 torch.nan* 应用,例如 torch.nanmaxtorch.nanmin 等。

通常,这些问题更自然地适用于掩码语义,因此我们建议使用 MaskedTensor 来代替引入额外的运算符。由于 nanmean 已经实现,我们可以将其用作比较点

# z is just y with the zeros replaced with nan's
# MaskedTensor successfully ignores the 0's

在上面的示例中,我们构造了一个 y,并希望在忽略零的情况下计算序列的均值。torch.nanmean 可以用来做到这一点,但我们没有其余 torch.nan* 操作的实现。MaskedTensor 通过能够使用基本操作来解决此问题,并且我们已经支持问题中列出的其他操作。例如

事实上,忽略 0 时最小参数的索引是索引 1 中的 1。

MaskedTensor 还可以支持数据完全被掩码时的 reductions,这等同于上述数据张量完全是 nan 的情况。nanmean 将返回 nan(一个模糊的返回值),而 MaskedTensor 将更准确地指示一个被掩码的结果。

这与安全 Softmax 有类似的问题,即 0/0 = nan,而我们真正想要的是未定义的值。

结论#

在本教程中,我们介绍了 MaskedTensors 是什么,演示了如何使用它们,并通过一系列示例和它们帮助解决的问题来阐述它们的价值。

进一步阅读#

要继续学习更多,您可以查阅我们的 MaskedTensor 稀疏性教程,了解 MaskedTensor 如何实现稀疏性以及我们目前支持的不同存储格式。

# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%

脚本总运行时间:(0 分 0.002 秒)