注意
转到末尾 下载完整的示例代码。
MaskedTensor 高级语义#
在开始本教程之前,请务必阅读我们的 MaskedTensor 概述教程 <https://pytorch.ac.cn/tutorials/prototype/maskedtensor_overview.html>。
本教程的目的是帮助用户理解一些高级语义的工作原理以及它们的由来。我们将重点关注两个特定的语义:
*. MaskedTensor 与 NumPy 的 MaskedArray 之间的差异 *. 规约语义
准备工作#
# Disable prototype warnings and such
MaskedTensor vs NumPy 的 MaskedArray#
NumPy 的 MaskedArray
与 MaskedTensor 在一些基本语义上存在差异。
- *. 它们的工厂函数和基本定义反转了掩码(类似于
torch.nn.MHA
);也就是说,MaskedTensor 使用
True
表示“已指定”,False
表示“未指定”或“有效”/“无效”,而 NumPy 则相反。我们认为我们的掩码定义不仅更直观,而且与 PyTorch 整体的现有语义也更一致。- *. 交集语义。在 NumPy 中,如果两个元素中的一个被掩盖掉,则结果元素也将被
掩盖掉——实际上,它们应用逻辑或运算符。
同时,MaskedTensor 不支持与不匹配的掩码进行加法或二进制运算——要理解原因,请查阅关于规约的部分。
然而,如果需要这种行为,MaskedTensor 通过提供对数据和掩码的访问,并使用 to_tensor()
方便地将 MaskedTensor 转换为填充了掩码值的张量,从而支持这些语义。例如
请注意,掩码是 mt0.get_mask() & mt1.get_mask(),因为 MaskedTensor
的掩码与 NumPy 的掩码是相反的。
规约语义#
回想在MaskedTensor 概述教程中,我们讨论了“实现缺失的 torch.nan* 操作”。这些是规约的例子——从张量中移除一个(或多个)维度,然后聚合结果的运算符。在本节中,我们将使用规约语义来解释我们对上方匹配掩码的严格要求。
从根本上说,:class:`MaskedTensor` 执行相同的规约操作,同时忽略被掩盖(未指定)的值。举个例子:
现在,不同的规约(都在 dim=1 上)
值得注意的是,被遮盖元素下的值不保证有任何特定值,特别是当行或列完全被遮盖时(标准化也如此)。有关遮盖语义的更多详细信息,您可以查阅此 RFC。
现在,我们可以重新审视这个问题:为什么我们强制要求二进制运算符的掩码必须匹配?换句话说,为什么我们不使用与 np.ma.masked_array
相同的语义?考虑以下示例
现在,让我们尝试加法
求和和加法显然应该满足结合律,但使用 NumPy 的语义,它们不满足,这肯定会让用户感到困惑。
MaskedTensor
则根本不允许此操作,因为 mask0 != mask1。话虽如此,如果用户需要,也有办法绕过此限制(例如,使用 to_tensor()
将 MaskedTensor 的未定义元素填充为 0 值,如下所示),但此时用户必须更明确地表达其意图。
结论#
在本教程中,我们学习了 MaskedTensor 和 NumPy 的 MaskedArray 背后的不同设计决策,以及规约语义。总的来说,MaskedTensor 旨在避免歧义和令人困惑的语义(例如,我们尝试在二进制操作中保留结合律),这反过来可能要求用户在某些时候对代码更具意图性,但我们认为这是更好的举动。如果您对此有任何想法,请告诉我们!
# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%
脚本总运行时间:(0 分 0.002 秒)