评价此页

嵌套张量入门#

嵌套张量将常规密集张量的形状泛化,允许表示不规则大小的数据。

  • 对于常规张量,每个维度都是规则的,并且具有一个大小

  • 对于嵌套张量,并非所有维度都具有规则的大小;其中一些是参差不齐的

嵌套张量是表示各种领域中顺序数据的自然解决方案

  • 在 NLP 中,句子的长度可变,因此句子的批次会形成一个嵌套张量

  • 在 CV 中,图像的形状可变,因此图像的批次会形成一个嵌套张量

在本教程中,我们将演示嵌套张量的基本用法,并通过实际示例说明其在处理可变长度序列数据时的实用性。特别是,它们对于构建能够高效处理参差不齐序列输入的 Transformer 至关重要。下面,我们使用嵌套张量实现多头注意力,结合使用 torch.compile,其性能优于在带有填充的张量上进行朴素操作。

嵌套张量目前是一项原型功能,可能会发生变化。

嵌套张量初始化#

从 Python 前端,可以从张量列表创建嵌套张量。我们用 nt[i] 表示嵌套张量的第 i 个张量分量。

通过将每个底层张量填充到相同的形状,可以将嵌套张量转换为常规张量。

所有张量都有一个用于确定它们是否为嵌套张量的属性;

通常从不规则形状张量的批次构建嵌套张量。即,假定维度 0 是批次维度。索引维度 0 会返回第一个底层张量分量。

# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.

一个重要的注意事项是,在维度 0 上进行切片尚未得到支持。这意味着目前无法构建一个结合底层张量分量的视图。

嵌套张量操作#

由于每个操作都必须为嵌套张量显式实现,因此嵌套张量的操作覆盖范围目前比常规张量窄。目前,仅涵盖基本操作,如索引、dropout、softmax、transpose、reshape、linear、bmm。但是,覆盖范围正在扩大。如果您需要某些操作,请提交一个issue 来帮助我们确定覆盖范围的优先级。

reshape

reshape 操作用于更改张量的形状。其对于常规张量的完整语义可在此处找到。对于常规张量,在指定新形状时,一个维度可以为 -1,在这种情况下,它会根据剩余维度和元素数量进行推断。

嵌套张量的语义类似,只是 -1 不再进行推断。相反,它继承旧的大小(此处为 nt[0] 的 2 和 nt[1] 的 3)。-1 是为参差不齐维度指定的唯一合法大小。

转置

transpose 操作用于交换张量的两个维度。其完整语义可在此处找到。请注意,对于嵌套张量,维度 0 是特殊的;它被假定为批次维度,因此不支持涉及嵌套张量维度 0 的转置。

其他

其他操作与常规张量的语义相同。将操作应用于嵌套张量等同于将操作应用于底层张量分量,结果也是一个嵌套张量。

为什么使用嵌套张量#

当数据是序列化的,通常每个样本的长度都不同。例如,在句子的批次中,每个句子的单词数量都不同。处理可变长度序列的一种常用技术是手动将每个数据张量填充到相同的形状,以形成批次。例如,我们有两个不同长度的句子和一个词汇表,为了将它们表示为单个张量,我们用 0 填充到批次中的最大长度。

将数据批次填充到最大长度的这种技术并非最优。填充的数据对于计算不是必需的,并且会通过分配比所需更大的张量来浪费内存。此外,并非所有操作在应用于填充数据时都具有相同的语义。对于矩阵乘法,为了忽略填充的条目,需要用 0 填充,而对于 softmax,则必须用 -inf 填充以忽略特定条目。嵌套张量的主要目标是使用标准的 PyTorch 张量用户体验来促进对参差不齐数据的操作,从而消除对低效且复杂的填充和掩码的需要。

让我们看一个实际示例:Transformer 中使用的多头注意力组件。我们可以以这样一种方式实现它,使其能够操作填充或嵌套的张量。

设置超参数,遵循Transformer 论文

除了 dropout 概率:设置为 0 以进行正确性检查

让我们根据 Zipf 定律生成一些真实的假数据。

创建嵌套张量批次输入

生成查询、键、值用于比较的填充形式

构建模型

检查正确性和性能

# padding-specific step: remove output projection bias from padded entries for fair comparison








# warm up compile first...


# ...now benchmark



# warm up compile first...

# ...now benchmark



# padding-specific step: remove output projection bias from padded entries for fair comparison

请注意,如果没有 torch.compile,Python 子类嵌套张量的开销可能会使其比等效的填充张量计算慢。但是,一旦启用了 torch.compile,操作嵌套张量可以实现多倍加速。随着批次中填充的百分比增加,避免对填充进行不必要的计算变得更加有价值。

结论#

在本教程中,我们学习了如何执行嵌套张量的基本操作以及如何实现 Transformer 的多头注意力,以避免对填充进行计算。有关更多信息,请查看torch.nested 命名空间文档。

另请参阅#

# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%

脚本总运行时间:(0 分 0.002 秒)