评价此页

MaskedTensor 高级语义#

在学习本教程之前,请务必回顾我们的 MaskedTensor 概览教程 <https://pytorch.ac.cn/tutorials/prototype/maskedtensor_overview.html>

本教程的目的是帮助用户理解一些高级语义的工作方式以及它们的由来。我们将重点关注其中两个

*。MaskedTensor 与 NumPy 的 MaskedArray 之间的差异 *。规约语义

准备工作#

# Disable prototype warnings and such

MaskedTensor vs NumPy 的 MaskedArray#

NumPy 的 MaskedArray 在一些基本语义上与 MaskedTensor 存在差异。

*。它们的工厂函数和基本定义反转了掩码(类似于 torch.nn.MHA);也就是说,MaskedTensor

使用 True 表示“已指定”和 False 表示“未指定”,或“有效”/“无效”,而 NumPy 则相反。我们认为我们的掩码定义不仅更直观,而且更符合 PyTorch 整体上已有的语义。

*。交集语义。在 NumPy 中,如果两个元素中的一个被掩盖掉,则结果元素将被

掩盖掉——实际上,它们 应用了 logical_or 运算符

与此同时,MaskedTensor 不支持具有不匹配掩码的加法或二元运算符——要了解原因,请参阅 规约部分

然而,如果需要这种行为,MaskedTensor 通过访问数据和掩码,并使用 to_tensor() 便利地将 MaskedTensor 转换为带有已填充掩码值的 Tensor 来支持这些语义。例如

请注意,掩码是 mt0.get_mask() & mt1.get_mask(),因为 MaskedTensor 的掩码与 NumPy 的掩码是相反的。

规约语义#

回想在 MaskedTensor 的概览教程 中,我们讨论了“实现缺失的 torch.nan* 算子”。这些就是规约的例子——会移除张量的一个(或多个)维度然后聚合结果的算子。在本节中,我们将使用规约语义来解释我们对上方匹配掩码的严格要求。

从根本上说,:class:`MaskedTensor` 执行相同的规约操作,同时忽略被掩盖掉(未指定)的值。举例来说

现在,不同的规约(都在 dim=1 上)

值得注意的是,被掩盖元素下的值不保证具有任何特定值,特别是如果该行或列完全被掩盖(归一化也是如此)。有关掩码语义的更多详细信息,请参阅此 RFC

现在,我们可以重新审视这个问题:为什么我们要强制执行二元运算符的掩码必须匹配的不变量?换句话说,为什么我们不使用与 np.ma.masked_array 相同的语义?请考虑以下示例

现在,我们来尝试加法

求和与加法应该是满足结合律的,但使用 NumPy 的语义,它们却不满足,这无疑会让用户感到困惑。

MaskedTensor,另一方面,由于 mask0 != mask1,将直接不允许此操作。也就是说,如果用户愿意,也有绕过此限制的方法(例如,使用 to_tensor() 将 MaskedTensor 的未定义元素填充为 0 值,如下所示),但用户现在必须更明确自己的意图。

结论#

在本教程中,我们学习了 MaskedTensor 和 NumPy 的 MaskedArray 背后的不同设计决策,以及规约语义。总的来说,MaskedTensor 的设计旨在避免歧义和令人困惑的语义(例如,我们试图在二元运算之间保持结合律),这有时可能需要用户在编写代码时更加有意识,但我们认为这是更好的选择。如果您对此有任何想法,请 告诉我们

# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%

脚本总运行时间:(0 分 0.002 秒)