
Getting Started with Nested Tensors

Nested tensors generalize the shape of regular dense tensors, allowing for the representation of ragged-sized data.

  • For a regular tensor, each dimension is regular and has a size.

  • For a nested tensor, not all dimensions have regular sizes; some of them are ragged.

Nested tensors are a natural solution for representing sequential data across various domains:

  • In NLP, sentences can have variable lengths, so a batch of sentences forms a nested tensor.

  • In CV, images can have variable shapes, so a batch of images forms a nested tensor.

In this tutorial, we will demonstrate basic usage of nested tensors and motivate their usefulness for operating on sequential data of varying lengths with a real-world example. In particular, they are invaluable for building transformers that can efficiently operate on ragged sequential inputs. Below, we present an implementation of multi-head attention using nested tensors that, combined with torch.compile, out-performs operating naively on tensors with padding.

Nested tensors are currently a prototype feature and are subject to change.

import numpy as np
import timeit
import torch
import torch.nn.functional as F

from torch import nn

torch.manual_seed(1)
np.random.seed(1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Nested Tensor Initialization

From the Python frontend, a nested tensor can be created from a list of tensors. We denote nt[i] as the ith tensor component of the nested tensor.

nt = torch.nested.nested_tensor([torch.arange(12).reshape(
    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
print(f"{nt=}")
/usr/local/lib/python3.10/dist-packages/torch/nested/__init__.py:250: UserWarning:

The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)

nt=nested_tensor([
  tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
          [ 6.,  7.,  8.,  9., 10., 11.]], device='cuda:0'),
  tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
          [ 6.,  7.,  8.,  9., 10., 11.],
          [12., 13., 14., 15., 16., 17.]], device='cuda:0')
], device='cuda:0')
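Following the recommendation in the warning above, a nested tensor can also be constructed with layout=torch.jagged, which receives active development, has better operator coverage, and works with torch.compile (the batch-generation code later in this tutorial uses it). A minimal sketch mirroring the construction above:

nt_jagged = torch.nested.nested_tensor(
    [torch.arange(12).reshape(2, 6), torch.arange(18).reshape(3, 6)],
    dtype=torch.float, device=device, layout=torch.jagged)
# the requested layout is reported via the standard .layout attribute
print(f"{nt_jagged.layout=}")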

By padding every underlying tensor to the same shape, a nested tensor can be converted to a regular tensor.
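The conversion below uses torch.nested.to_padded_tensor; the padding value 0.0 matches the zero entries visible in the output that follows.

padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)
print(f"{padded_out_tensor=}")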

padded_out_tensor=tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [ 0.,  0.,  0.,  0.,  0.,  0.]],

        [[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.]]], device='cuda:0')

All tensors possess an attribute for determining whether they are nested:

print(f"nt is nested: {nt.is_nested}")
print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")
nt is nested: True
padded_out_tensor is nested: False

It is common to construct nested tensors from batches of irregularly shaped tensors, i.e. dimension 0 is assumed to be the batch dimension. Indexing dimension 0 gives back the first underlying tensor component.

print("First underlying tensor component:", nt[0], sep='\n')
print("last column of 2nd underlying tensor component:", nt[1, :, -1], sep='\n')

# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.
print(f"First underlying tensor component is nested: {nt[0].is_nested}")
First underlying tensor component:
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10., 11.]], device='cuda:0')
last column of 2nd underlying tensor component:
tensor([ 5., 11., 17.], device='cuda:0')
First underlying tensor component is nested: False

An important note is that slicing in dimension 0 is not supported yet, which means it is not currently possible to construct a view that combines the underlying tensor components.
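As a quick illustration (a minimal sketch; the exact error type is prototype behavior and may change), slicing along dimension 0 is expected to fail, whereas the integer indexing shown above works:

try:
    nt[0:1]  # attempt to slice the batch dimension of the nested tensor
except Exception as e:
    print(f"Slicing dimension 0 raised {type(e).__name__}: {e}")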

Nested Tensor Operations

As each operation must be explicitly implemented for nested tensors, operation coverage for nested tensors is currently narrower than that of regular tensors. For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm are covered. However, coverage is being expanded. If you need certain operations, please file an issue to help us prioritize coverage.

Reshape

The reshape op is for changing the shape of a tensor. Its full semantics for regular tensors can be found here. For regular tensors, when specifying the new shape, a single dimension may be -1, in which case it is inferred from the remaining dimensions and the number of elements.

The semantics for nested tensors are similar, except that -1 no longer infers a size. Instead, it inherits the old size (here 2 for nt[0] and 3 for nt[1]). -1 is the only legal size to specify for a jagged dimension.

nt_reshaped = nt.reshape(2, -1, 2, 3)
print(f"{nt_reshaped=}")
nt_reshaped=nested_tensor([
  tensor([[[ 0.,  1.,  2.],
           [ 3.,  4.,  5.]],

          [[ 6.,  7.,  8.],
           [ 9., 10., 11.]]], device='cuda:0'),
  tensor([[[ 0.,  1.,  2.],
           [ 3.,  4.,  5.]],

          [[ 6.,  7.,  8.],
           [ 9., 10., 11.]],

          [[12., 13., 14.],
           [15., 16., 17.]]], device='cuda:0')
], device='cuda:0')

Transpose

The transpose op is for swapping two dimensions of a tensor. Its full semantics can be found here. Note that for nested tensors, dimension 0 is special; it is assumed to be the batch dimension, so transposes involving nested tensor dimension 0 are not supported.

nt_transposed = nt_reshaped.transpose(1, 2)
print(f"{nt_transposed=}")
nt_transposed=nested_tensor([
  tensor([[[ 0.,  1.,  2.],
           [ 6.,  7.,  8.]],

          [[ 3.,  4.,  5.],
           [ 9., 10., 11.]]], device='cuda:0'),
  tensor([[[ 0.,  1.,  2.],
           [ 6.,  7.,  8.],
           [12., 13., 14.]],

          [[ 3.,  4.,  5.],
           [ 9., 10., 11.],
           [15., 16., 17.]]], device='cuda:0')
], device='cuda:0')
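To illustrate the restriction mentioned above (again a minimal sketch; the exact error type may change as the prototype evolves), a transpose involving dimension 0 is expected to raise:

try:
    nt_reshaped.transpose(0, 1)  # dimension 0 is the batch dimension and cannot be transposed
except Exception as e:
    print(f"Transposing with dimension 0 raised {type(e).__name__}: {e}")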

Others

Other operations have the same semantics as for regular tensors. Applying an operation on a nested tensor is equivalent to applying it to the underlying tensor components, with the result being a nested tensor as well.

nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
nt3 = torch.matmul(nt_transposed, nt_mm)
print(f"Result of Matmul:\n {nt3}")

nt4 = F.dropout(nt3, 0.1)
print(f"Result of Dropout:\n {nt4}")

nt5 = F.softmax(nt4, -1)
print(f"Result of Softmax:\n {nt5}")
Result of Matmul:
 nested_tensor([
  tensor([[[  0.7781,   1.7332,   2.5551,  -1.7998],
           [ -6.3416,   0.6039,   3.3571, -21.6835]],

          [[ -3.0563,   1.1609,  -6.8225,  19.4126],
           [ -7.3476,  -0.8315, -15.4485,  44.0489]]], device='cuda:0'),
  tensor([[[ -0.7215,   3.0998,  -0.2846,   4.7335,   3.6254],
           [-17.8239,   9.9335,  14.5221,  25.6358,  15.9261],
           [-34.9263,  16.7672,  29.3289,  46.5381,  28.2268]],

          [[  5.9445,   3.1823,   7.7202, -15.5639,   9.8096],
           [ 13.5947,   9.8521,  19.5695, -38.9003,  20.3403],
           [ 21.2450,  16.5219,  31.4188, -62.2367,  30.8710]]], device='cuda:0')
], device='cuda:0')
Result of Dropout:
 nested_tensor([
  tensor([[[  0.0000,   1.9258,   2.8390,  -1.9998],
           [ -7.0462,   0.6710,   3.7301,  -0.0000]],

          [[ -3.3959,   0.0000,  -0.0000,  21.5696],
           [ -8.1640,  -0.9239, -17.1650,  48.9432]]], device='cuda:0'),
  tensor([[[ -0.8017,   3.4442,  -0.0000,   5.2595,   4.0282],
           [-19.8043,   0.0000,  16.1357,  28.4842,  17.6957],
           [ -0.0000,  18.6302,  32.5877,  51.7090,  31.3631]],

          [[  6.6050,   0.0000,   8.5781, -17.2933,  10.8996],
           [ 15.1053,  10.9468,  21.7439, -43.2226,  22.6003],
           [  0.0000,  18.3577,  34.9098,  -0.0000,  34.3011]]], device='cuda:0')
], device='cuda:0')
Result of Softmax:
 nested_tensor([
  tensor([[[3.9850e-02, 2.7339e-01, 6.8136e-01, 5.3942e-03],
           [1.9504e-05, 4.3819e-02, 9.3376e-01, 2.2400e-02]],

          [[1.4375e-11, 4.2900e-10, 4.2900e-10, 1.0000e+00],
           [1.5800e-25, 2.2030e-22, 1.9480e-29, 1.0000e+00]]], device='cuda:0'),
  tensor([[[1.5946e-03, 1.1133e-01, 3.5548e-03, 6.8387e-01, 1.9964e-01],
           [1.0679e-21, 4.2604e-13, 4.3361e-06, 9.9998e-01, 2.0634e-05],
           [3.4921e-23, 4.3061e-15, 4.9628e-09, 1.0000e+00, 1.4585e-09]],

          [[1.2271e-02, 1.6610e-05, 8.8259e-02, 5.1285e-13, 8.9945e-01],
           [3.8999e-04, 6.0961e-06, 2.9797e-01, 1.8180e-29, 7.0163e-01],
           [4.4690e-16, 4.1963e-08, 6.4764e-01, 4.4690e-16, 3.5236e-01]]],
         device='cuda:0')
], device='cuda:0')

Why Nested Tensors

When data is sequential, it is often the case that each sample has a different length. For example, in a batch of sentences, each sentence has a different number of words. A common technique for handling varying sequences is to manually pad each data tensor to the same shape in order to form a batch. For example, we have two sentences of different lengths and a vocabulary. To represent them as a single tensor, we pad with 0 to the longest length in the batch.

sentences = [["goodbye", "padding"],
             ["embrace", "nested", "tensor"]]
vocabulary = {"goodbye": 1.0, "padding": 2.0,
              "embrace": 3.0, "nested": 4.0, "tensor": 5.0}
padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                 [3.0, 4.0, 5.0]])
nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
                                               torch.tensor([3.0, 4.0, 5.0])])
print(f"{padded_sentences=}")
print(f"{nested_sentences=}")
padded_sentences=tensor([[1., 2., 0.],
        [3., 4., 5.]])
nested_sentences=nested_tensor([
  tensor([1., 2.]),
  tensor([3., 4., 5.])
])

This technique of padding a batch of data to its maximum length is not optimal. The padded data is not needed for computation and wastes memory by allocating larger tensors than necessary. Further, not all operations have the same semantics when applied to padded data. For matrix multiplications, one pads with 0 in order to ignore the padded entries, while for softmax one has to pad with -inf to ignore specific entries. The primary objective of nested tensors is to facilitate operations on ragged data using the standard PyTorch tensor UX, thereby eliminating the need for inefficient and complex padding and masking.

padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                             [3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
print(F.softmax(nested_sentences, -1))
tensor([[0.2689, 0.7311, 0.0000],
        [0.0900, 0.2447, 0.6652]])
nested_tensor([
  tensor([0.2689, 0.7311]),
  tensor([0.0900, 0.2447, 0.6652])
])

Let us take a look at a practical example: the multi-head attention component used in Transformers. We can implement it in such a way that it can operate on either padded or nested tensors.

class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout_p (float, optional): Dropout probability. Default: 0.0
    """
    def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,
                 nheads: int, dropout_p: float = 0.0):
        super().__init__()
        self.nheads = nheads
        self.dropout_p = dropout_p
        self.query_proj = nn.Linear(E_q, E_total)
        self.key_proj = nn.Linear(E_k, E_total)
        self.value_proj = nn.Linear(E_v, E_total)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (N, L_t, E_q)
            key (torch.Tensor): key of shape (N, L_s, E_k)
            value (torch.Tensor): value of shape (N, L_s, E_v)

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        # TODO: demonstrate packed projection
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout_p, is_causal=True)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

Set hyperparameters following the Transformer paper.

N = 512
E_q, E_k, E_v, E_total = 512, 512, 512, 512
E_out = E_q
nheads = 8

Except for the dropout probability, which is set to 0 for the correctness check.

dropout_p = 0.0

Let us generate some realistic fake data from the Zipf law.

def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)

Create nested tensor batch inputs.

def gen_batch(N, E_q, E_k, E_v, device):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    query = torch.nested.nested_tensor([
        torch.randn(l.item(), E_q, device=device)
        for l in sentence_lengths
    ], layout=torch.jagged)

    key = torch.nested.nested_tensor([
        torch.randn(s.item(), E_k, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    value = torch.nested.nested_tensor([
        torch.randn(s.item(), E_v, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    return query, key, value, sentence_lengths

query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)

Generate padded forms of query, key, and value for comparison.

def jagged_to_padded(jt, padding_val):
    # TODO: do jagged -> padded directly when this is supported
    return torch.nested.to_padded_tensor(
        torch.nested.nested_tensor(list(jt.unbind())),
        padding_val)

padded_query, padded_key, padded_value = (
    jagged_to_padded(t, 0.0) for t in (query, key, value)
)

Construct the model.

mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout_p).to(device=device)

Check correctness and performance.

def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin)

output_nested, time_nested = benchmark(mha, query, key, value)
output_padded, time_padded = benchmark(mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    output_padded[i, entry_length:] = 0.0

print("=== without torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(output_nested, 0.0) - output_padded).abs().max().item())
print("nested tensor multi-head attention takes", time_nested, "seconds")
print("padded tensor multi-head attention takes", time_padded, "seconds")

# warm up compile first...
compiled_mha = torch.compile(mha)
compiled_mha(query, key, value)
# ...now benchmark
compiled_output_nested, compiled_time_nested = benchmark(
    compiled_mha, query, key, value)

# warm up compile first...
compiled_mha(padded_query, padded_key, padded_value)
# ...now benchmark
compiled_output_padded, compiled_time_padded = benchmark(
    compiled_mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    compiled_output_padded[i, entry_length:] = 0.0

print("=== with torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(compiled_output_nested, 0.0) - compiled_output_padded).abs().max().item())
print("nested tensor multi-head attention takes", compiled_time_nested, "seconds")
print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")
=== without torch.compile ===
nested and padded calculations differ by 0.0
nested tensor multi-head attention takes 0.012142510999865408 seconds
padded tensor multi-head attention takes 0.00962024599994038 seconds
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:282: UserWarning:

TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.

=== with torch.compile ===
nested and padded calculations differ by 0.0
nested tensor multi-head attention takes 0.0028233499999714695 seconds
padded tensor multi-head attention takes 0.009543115000042235 seconds

Note that without torch.compile, the overhead of the Python-subclass nested tensor can make it slower than the equivalent computation on padded tensors. Once torch.compile is enabled, however, operating on nested tensors gives a multiple-x speedup. Avoiding wasted computation on padding becomes only more valuable as the percentage of padding in the batch increases.

print(f"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}")
Nested speedup: 3.380

Conclusion

In this tutorial, we have learned how to perform basic operations with nested tensors and how to implement multi-head attention for transformers in a way that avoids computation on padding. For more information, check out the docs for the torch.nested namespace.

See Also

Total running time of the script: (0 minutes 6.347 seconds)