torch.nested#
创建日期:2022 年 3 月 2 日 | 最后更新日期:2025 年 6 月 14 日
引言#
警告
PyTorch 嵌套张量 API 仍处于原型阶段,近期将有所变化。
嵌套张量允许将不规则形状的数据包含在一个张量中并对其进行操作。此类数据以高效的打包表示形式存储在底层,同时暴露标准 PyTorch 张量接口以应用操作。
嵌套张量的一个常见应用是表达各种领域中可变长度序列数据的批处理,例如不同的句子长度、图像大小和音频/视频剪辑长度。传统上,此类数据是通过将序列填充到批处理中的最大长度,对填充形式执行计算,然后进行掩码以移除填充来处理的。这既低效又容易出错,嵌套张量的存在就是为了解决这些问题。
对嵌套张量执行操作的 API 与常规 torch.Tensor
的 API 没有区别,允许与现有模型无缝集成,主要区别在于输入的构造。
由于这是一个原型功能,支持的操作集有限,但正在不断增长。我们欢迎问题、功能请求和贡献。有关贡献的更多信息可以在此自述文件中找到。
构造#
注意
PyTorch 中存在两种形式的嵌套张量,通过构造期间指定的布局来区分。布局可以是 torch.strided
或 torch.jagged
之一。我们建议尽可能使用 torch.jagged
布局。虽然它目前只支持一个不规则维度,但它具有更好的操作覆盖率,正在积极开发中,并且与 torch.compile
很好地集成。这些文档遵循此建议,为简洁起见,将具有 torch.jagged
布局的嵌套张量称为“NJT”。
构造很简单,只需将张量列表传递给 torch.nested.nested_tensor
构造函数即可。具有 torch.jagged
布局(又名“NJT”)的嵌套张量支持单个不规则维度。此构造函数将根据下面 data_layout
_ 部分中描述的布局,将输入张量复制到打包的连续内存块中。
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> print([component for component in nt])
[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])]
列表中的每个张量必须具有相同的维度数,但形状可以在单个维度上有所不同。如果输入组件的维度不匹配,则构造函数会抛出错误。
>>> a = torch.randn(50, 128) # 2D tensor
>>> b = torch.randn(2, 50, 128) # 3D tensor
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
...
RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim
在构造过程中,可以通过常用关键字参数选择 dtype、device 以及是否需要梯度。
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True)
>>> print([component for component in nt])
[tensor([0., 1., 2.], device='cuda:0',
grad_fn=<UnbindBackwardAutogradNestedTensor0>), tensor([3., 4., 5., 6., 7.], device='cuda:0',
grad_fn=<UnbindBackwardAutogradNestedTensor0>)]
torch.nested.as_nested_tensor
可用于保留从传递给构造函数的张量中获得的 autograd 历史记录。当使用此构造函数时,梯度将通过嵌套张量流回原始组件。请注意,此构造函数仍会将输入组件复制到打包的连续内存块中。
>>> a = torch.randn(12, 512, requires_grad=True)
>>> b = torch.randn(23, 512, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.sum().backward()
>>> a.grad
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
>>> b.grad
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
上述所有函数都创建了连续的 NJT,其中分配了一块内存来存储底层组件的打包形式(有关详细信息,请参阅下面的 data_layout
_ 部分)。
也可以使用 torch.nested.narrow()
在预先存在的带填充的密集张量上创建非连续 NJT 视图,从而避免内存分配和复制。
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt.is_contiguous()
False
请注意,嵌套张量充当原始填充密集张量的视图,引用相同的内存而无需复制/分配。对非连续 NJT 的操作支持有所限制,因此如果您遇到支持空白,始终可以使用 contiguous()
转换为连续 NJT。
数据布局和形状#
为了提高效率,嵌套张量通常将其张量组件打包成连续的内存块,并维护额外的元数据以指定批处理项边界。对于 torch.jagged
布局,连续内存块存储在 values
组件中,而 offsets
组件则划分不规则维度的批处理项边界。
如有必要,可以直接访问底层 NJT 组件。
>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.values().shape # note the "packing" of the ragged dimension; no padding needed
torch.Size([82, 128])
>>> nt.offsets()
tensor([ 0, 50, 82])
直接从锯齿状 values
和 offsets
成分构造 NJT 也很有用;torch.nested.nested_tensor_from_jagged()
构造函数就是为了这个目的。
>>> values = torch.randn(82, 128)
>>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64)
>>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets)
NJT 具有定义明确的形状,其维度比其组件的维度大 1。不规则维度的底层结构由符号值(下例中的 j1
)表示。
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.dim()
3
>>> nt.shape
torch.Size([2, j1, 128])
NJT 必须具有相同的不规则结构才能相互兼容。例如,要执行涉及两个 NJT 的二元操作,不规则结构必须匹配(即,它们的形状中必须具有相同的不规则形状符号)。具体来说,每个符号对应一个精确的 offsets
张量,因此两个 NJT 必须具有相同的 offsets
张量才能相互兼容。
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt1.offsets() is nt2.offsets()
False
>>> nt3 = nt1 + nt2
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
在上述示例中,尽管两个 NJT 的概念形状相同,但它们不共享对同一 offsets
张量的引用,因此它们的形状不同,并且不兼容。我们认识到这种行为是不直观的,并且正在努力在嵌套张量的 beta 版本中放宽此限制。有关解决方法,请参阅本文档的故障排除部分。
除了 offsets
元数据之外,NJT 还可以计算和缓存其组件的最小和最大序列长度,这对于调用特定内核(例如 SDPA)可能很有用。目前没有公开的 API 来访问这些,但这将在 beta 版本中改变。
支持的操作#
本节包含您可能会觉得有用的嵌套张量常见操作列表。它并不全面,因为 PyTorch 中有大约几千个操作。虽然今天嵌套张量支持其中相当一部分,但全面支持是一项艰巨的任务。嵌套张量的理想状态是全面支持非嵌套张量可用的所有 PyTorch 操作。为了帮助我们实现这一目标,请考虑
在此请求您的用例所需的特定操作,以帮助我们确定优先级。
贡献!为给定的 PyTorch 操作添加嵌套张量支持并不太难;有关详细信息,请参阅下面的贡献部分。
查看嵌套张量成分#
unbind()
允许您检索嵌套张量成分的视图。
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.unbind()
(tensor([[-0.9916, -0.3363, -0.2799],
[-2.3520, -0.5896, -0.4374]]), tensor([[-2.0969, -1.0104, 1.4841],
[ 2.0952, 0.2973, 0.2516],
[ 0.9035, 1.3623, 0.2026]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
[-2.3481, 2.0236, 0.1975]])
>>> nt.unbind()
(tensor([[-2.9747, -1.0089, -0.8396],
[-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104, 1.4841],
[ 2.0952, 0.2973, 0.2516],
[ 0.9035, 1.3623, 0.2026]]))
请注意,nt.unbind()[0]
不是副本,而是底层内存的一个切片,它代表嵌套张量的第一个条目或成分。
与填充之间的转换#
torch.nested.to_padded_tensor()
将 NJT 转换为具有指定填充值的带填充密集张量。不规则维度将被填充到最大序列长度的大小。
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(6, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> padded = torch.nested.to_padded_tensor(nt, padding=4.2)
>>> padded
tensor([[[ 1.6107, 0.5723, 0.3913],
[ 0.0700, -0.4954, 1.8663],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000]],
[[-0.0479, -0.7610, -0.3484],
[ 1.1345, 1.0556, 0.3634],
[-1.7122, -0.5921, 0.0540],
[-0.5506, 0.7608, 2.0606],
[ 1.5658, -1.1934, 0.3041],
[ 0.1483, -1.1284, 0.6957]]])
这可以作为一种变通方法,以解决 NJT 支持的不足,但理想情况下,应尽可能避免此类转换,以实现最佳内存使用和性能,因为更高效的嵌套张量布局不会具体化填充。
可以使用 torch.nested.narrow()
完成反向转换,该函数将不规则结构应用于给定密集张量以生成 NJT。请注意,默认情况下,此操作不会复制底层数据,因此输出 NJT 通常是不连续的。如果需要连续的 NJT,在此处显式调用 contiguous()
可能会很有用。
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt = nt.contiguous()
>>> nt.shape
torch.Size([3, j2, 4])
形状操作#
嵌套张量支持多种形状操作,包括视图。
>>> a = torch.randn(2, 6)
>>> b = torch.randn(4, 6)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.shape
torch.Size([2, j1, 6])
>>> nt.unsqueeze(-1).shape
torch.Size([2, j1, 6, 1])
>>> nt.unflatten(-1, [2, 3]).shape
torch.Size([2, j1, 2, 3])
>>> torch.cat([nt, nt], dim=2).shape
torch.Size([2, j1, 12])
>>> torch.stack([nt, nt], dim=2).shape
torch.Size([2, j1, 2, 6])
>>> nt.transpose(-1, -2).shape
torch.Size([2, 6, j1])
注意力机制#
由于变长序列是注意力机制的常见输入,嵌套张量支持重要的注意力运算符 Scaled Dot Product Attention (SDPA) 和 FlexAttention。有关 NJT 与 SDPA 的用法示例,请参见此处;有关 NJT 与 FlexAttention 的用法示例,请参见此处。
与 torch.compile 的用法#
NJT 旨在与 torch.compile()
配合使用以实现最佳性能,我们始终建议尽可能将 torch.compile()
与 NJT 结合使用。NJT 可以开箱即用,并且在作为输入传递给编译函数或模块时,或者在函数内部内联实例化时,都不会出现图中断。
注意
If you're not able to utilize ``torch.compile()`` for your use case, performance and memory
usage may still benefit from the use of NJTs, but it's not as clear-cut whether this will be
the case. It is important that the tensors being operated on are large enough so the
performance gains are not outweighed by the overhead of python tensor subclasses.
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
>>> output.shape
torch.Size([2, j1, 3])
>>> def g(values, offsets): return torch.nested.nested_tensor_from_jagged(values, offsets) * 2.
...
>>> compiled_g = torch.compile(g, fullgraph=True)
>>> output2 = compiled_g(nt.values(), nt.offsets())
>>> output2.shape
torch.Size([2, j1, 3])
请注意,NJT 支持动态形状,以避免因不规则结构变化而导致不必要的重新编译。
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> c = torch.randn(5, 3)
>>> d = torch.randn(6, 3)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt2 = torch.nested.nested_tensor([c, d], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output1 = compiled_f(nt1)
>>> output2 = compiled_f(nt2) # NB: No recompile needed even though ragged structure differs
如果您在使用 NJT + torch.compile
时遇到问题或难以理解的错误,请提交 PyTorch 问题。在 torch.compile
中完全支持子类是一项长期努力,目前可能存在一些不足之处。
故障排除#
本节包含在使用嵌套张量时可能遇到的常见错误,以及这些错误的原因和解决它们的建议。
未实现的运算#
随着嵌套张量运算支持的增长,此类错误正变得越来越少见,但考虑到 PyTorch 中有数千个运算,今天仍然可能遇到它。
NotImplementedError: aten.view_as_real.default
这个错误很简单;我们还没有为这个特定的运算添加运算支持。如果您愿意,可以自行贡献实现,或者简单地请求我们在未来的 PyTorch 版本中添加对该运算的支持。
不规则结构不兼容#
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
当调用对多个 NJT 进行操作且其不规则结构不兼容的运算符时,会发生此错误。目前,要求输入 NJT 具有完全相同的 offsets
组成部分,才能具有相同的符号不规则结构符号(例如 j1
)。
作为这种情况的变通方法,可以直接从 values
和 offsets
组件构建 NJT。如果两个 NJT 都引用相同的 offsets
组件,则它们被认为具有相同的锯齿状结构,因此兼容。
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor_from_jagged(values=torch.randn(82, 128), offsets=nt1.offsets())
>>> nt3 = nt1 + nt2
>>> nt3.shape
torch.Size([2, j1, 128])
torch.compile 中的数据依赖操作#
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
当在 torch.compile 中调用执行数据依赖操作的运算符时,会发生此错误;这通常发生在需要检查 NJT 的 offsets
值以确定输出形状的运算符上。例如:
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> def f(nt): return nt.chunk(2, dim=0)[0]
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
在此示例中,对 NJT 的批处理维度调用 chunk()
需要检查 NJT 的 offsets
数据,以区分打包不规则维度中的批处理项边界。作为一种解决方法,可以设置几个 torch.compile 标志:
>>> torch._dynamo.config.capture_dynamic_output_shape_ops = True
>>> torch._dynamo.config.capture_scalar_outputs = True
如果设置这些参数后仍然出现数据依赖运算符错误,请向 PyTorch 提交问题。 torch.compile()
的这一领域仍在大力开发中,NJT 支持的某些方面可能尚不完整。
贡献#
如果您想为嵌套张量开发做出贡献,最有效的方法之一是为当前不支持的 PyTorch 操作添加嵌套张量支持。此过程通常包括几个简单步骤:
确定要添加的操作的名称;这应该是类似于
aten.view_as_real.default
的内容。此操作的签名可以在aten/src/ATen/native/native_functions.yaml
中找到。在
torch/nested/_internal/ops.py
中注册一个操作实现,遵循为其他操作建立的模式。使用native_functions.yaml
中的签名进行模式验证。
实现一个操作最常见的方法是将 NJT 解包为其组成部分,在底层 values
缓冲区上重新调度该操作,并将相关的 NJT 元数据(包括 offsets
)传播到新的输出 NJT。如果操作的输出预计与输入具有不同的形状,则必须计算新的 offsets
等元数据。
当操作应用于批次或不规则维度时,这些技巧可以帮助快速获得工作实现:
对于*非批次*操作,基于
unbind()
的回退应该有效。对于不规则维度上的操作,请考虑转换为带填充的密集张量,并选择一个不会对输出产生负面偏差的适当填充值,然后运行操作,再转换回 NJT。在
torch.compile
中,这些转换可以融合,以避免具体化带填充的中间结果。
构造和转换函数的详细文档#
- torch.nested.nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False)[source]#
从张量列表
tensor_list
构造一个没有自动梯度历史(也称为“叶张量”,参见自动梯度机制)的嵌套张量。- 参数
tensor_list (List[array_like]) – 张量列表,或任何可以传递给 torch.tensor 的内容,
dimensionality. (其中列表的每个元素具有相同的维度。) –
- 关键字参数
dtype (
torch.dtype
, optional) – 返回嵌套张量的期望类型。默认值:如果为 None,则与列表中最左侧张量的torch.dtype
相同。layout (
torch.layout
, optional) – 返回嵌套张量的期望布局。仅支持 stride 和 jagged 布局。默认值:如果为 None,则为 stride 布局。device (
torch.device
, optional) – 返回嵌套张量的期望设备。默认值:如果为 None,则与列表中最左侧张量的torch.device
相同requires_grad (bool, optional) – 如果自动梯度应记录返回嵌套张量上的操作。默认值:
False
。pin_memory (bool, optional) – 如果设置,返回的嵌套张量将分配在锁定内存中。仅适用于 CPU 张量。默认值:
False
。
- 返回类型
示例
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True) >>> b = torch.arange(5, dtype=torch.float, requires_grad=True) >>> nt = torch.nested.nested_tensor([a, b], requires_grad=True) >>> nt.is_leaf True
- torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None, jagged_dim=None, min_seqlen=None, max_seqlen=None)[source]#
根据给定的锯齿状组件构造一个锯齿状布局嵌套张量。锯齿状布局包含一个必需的值缓冲区,其中锯齿状维度被打包到一个维度中。偏移量/长度元数据决定了该维度如何拆分为批处理元素,并且预计与值缓冲区分配在同一设备上。
- 预期元数据格式
偏移量:打包维度内的索引,将其拆分为异构大小的批处理元素。示例:[0, 2, 3, 6] 表示大小为 6 的打包锯齿状维度应概念上拆分为长度为 [2, 1, 3] 的批处理元素。请注意,为了内核方便,需要起始和结束偏移量(即形状 batch_size + 1)。
长度:各个批次元素的长度;形状 == batch_size。示例:[2, 1, 3] 表示大小为 6 的打包锯齿维度应在概念上拆分为长度为 [2, 1, 3] 的批次元素。
请注意,同时提供偏移量和长度可能很有用。这描述了一个带有“空洞”的嵌套张量,其中偏移量表示每个批处理项的起始位置,长度指定元素的总数(参见下面的示例)。
返回的锯齿布局嵌套张量将是输入值张量的一个视图。
- 参数
values (
torch.Tensor
) – 底层缓冲区,形状为 (sum_B(*), D_1, …, D_N)。锯齿维度被打包成一个单一维度,并使用偏移量/长度元数据来区分批处理元素。offsets (optional
torch.Tensor
) – 形状为 B + 1 的锯齿维度偏移量。lengths (optional
torch.Tensor
) – 形状为 B 的批处理元素的长度。jagged_dim (optional python:int) – 表示值中哪个维度是打包的锯齿状维度。如果为 None,则将其设置为 dim=1(即紧跟批次维度后的维度)。默认值:None
min_seqlen (optional python:int) – 如果设置,则使用指定值作为返回嵌套张量的缓存最小序列长度。这可以替代按需计算此值,可能避免 GPU -> CPU 同步。默认值:None
max_seqlen (optional python:int) – 如果设置,则使用指定值作为返回嵌套张量的缓存最大序列长度。这可以作为按需计算此值的有用替代方案,可能避免 GPU -> CPU 同步。默认值:None
- 返回类型
示例
>>> values = torch.randn(12, 5) >>> offsets = torch.tensor([0, 3, 5, 6, 10, 12]) >>> nt = nested_tensor_from_jagged(values, offsets) >>> # 3D shape with the middle dimension jagged >>> nt.shape torch.Size([5, j2, 5]) >>> # Length of each item in the batch: >>> offsets.diff() tensor([3, 2, 1, 4, 2]) >>> values = torch.randn(6, 5) >>> offsets = torch.tensor([0, 2, 3, 6]) >>> lengths = torch.tensor([1, 1, 2]) >>> # NT with holes >>> nt = nested_tensor_from_jagged(values, offsets, lengths) >>> a, b, c = nt.unbind() >>> # Batch item 1 consists of indices [0, 1) >>> torch.equal(a, values[0:1, :]) True >>> # Batch item 2 consists of indices [2, 3) >>> torch.equal(b, values[2:3, :]) True >>> # Batch item 3 consists of indices [3, 5) >>> torch.equal(c, values[3:5, :]) True
- torch.nested.as_nested_tensor(ts, dtype=None, device=None, layout=None)[source]#
从张量或张量列表/元组构造一个保留自动梯度历史的嵌套张量。
如果传入的是嵌套张量,则除非设备/数据类型/布局不同,否则将直接返回。请注意,转换设备/数据类型会导致复制,而此函数目前不支持转换布局。
如果传入的是非嵌套张量,则将其视为批处理的统一大小的组成部分。如果传入的设备/数据类型与输入的不同,或者输入不连续,则会发生复制。否则,将直接使用输入的存储。
如果提供了张量列表,则在构造嵌套张量期间始终复制列表中的张量。
- 参数
ts (Tensor or List[Tensor] or Tuple[Tensor]) – 要视为嵌套张量的张量,或具有相同 ndim 的张量列表/元组
- 关键字参数
dtype (
torch.dtype
, optional) – 返回嵌套张量的期望类型。默认值:如果为 None,则与列表中最左侧张量的torch.dtype
相同。device (
torch.device
, optional) – 返回嵌套张量的期望设备。默认值:如果为 None,则与列表中最左侧张量的torch.device
相同layout (
torch.layout
, optional) – 返回嵌套张量的期望布局。仅支持 stride 和 jagged 布局。默认值:如果为 None,则为 stride 布局。
- 返回类型
示例
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True) >>> b = torch.arange(5, dtype=torch.float, requires_grad=True) >>> nt = torch.nested.as_nested_tensor([a, b]) >>> nt.is_leaf False >>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)]) >>> nt.backward(fake_grad) >>> a.grad tensor([1., 1., 1.]) >>> b.grad tensor([0., 0., 0., 0., 0.]) >>> c = torch.randn(3, 5, requires_grad=True) >>> nt2 = torch.nested.as_nested_tensor(c)
- torch.nested.to_padded_tensor(input, padding, output_size=None, out=None) Tensor #
通过填充
input
嵌套张量来返回一个新的(非嵌套)张量。前导条目将填充嵌套数据,而尾随条目将被填充。警告
to_padded_tensor()
总是复制底层数据,因为嵌套张量和非嵌套张量在内存布局上不同。- 参数
padding (float) – 尾部条目的填充值。
- 关键字参数
示例
>>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))]) nested_tensor([ tensor([[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276], [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995]]), tensor([[-1.8546, -0.7194, -0.2918, -0.1846], [ 0.2773, 0.8793, -0.5183, -0.6447], [ 1.8009, 1.8468, -0.9832, -1.5272]]) ]) >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0) tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276], [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], [[-1.8546, -0.7194, -0.2918, -0.1846, 0.0000], [ 0.2773, 0.8793, -0.5183, -0.6447, 0.0000], [ 1.8009, 1.8468, -0.9832, -1.5272, 0.0000]]]) >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6)) tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276, 1.0000], [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995, 1.0000], [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], [[-1.8546, -0.7194, -0.2918, -0.1846, 1.0000, 1.0000], [ 0.2773, 0.8793, -0.5183, -0.6447, 1.0000, 1.0000], [ 1.8009, 1.8468, -0.9832, -1.5272, 1.0000, 1.0000], [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]]) >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2)) RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.
- torch.nested.masked_select(tensor, mask)[source]#
根据步幅张量输入和步幅掩码构造嵌套张量,生成的锯齿状布局嵌套张量将保留掩码为 True 的值。掩码的维度保持不变,并用偏移量表示,这与
masked_select()
不同,后者将输出折叠为一维张量。参数: tensor (
torch.Tensor
):用于构造锯齿形布局嵌套张量的步幅张量。mask (torch.Tensor
):应用于张量输入的步幅掩码张量。示例
>>> tensor = torch.randn(3, 3) >>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]]) >>> nt = torch.nested.masked_select(tensor, mask) >>> nt.shape torch.Size([3, j4]) >>> # Length of each item in the batch: >>> nt.offsets().diff() tensor([1, 2, 1]) >>> tensor = torch.randn(6, 5) >>> mask = torch.tensor([False]) >>> nt = torch.nested.masked_select(tensor, mask) >>> nt.shape torch.Size([6, j5]) >>> # Length of each item in the batch: >>> nt.offsets().diff() tensor([0, 0, 0, 0, 0, 0])
- 返回类型
- torch.nested.narrow(tensor, dim, start, length, layout=torch.strided)[source]#
从
tensor
(一个步幅张量)构造一个嵌套张量(可能是一个视图)。这遵循与 torch.Tensor.narrow 类似的语义,即在第dim
维度中,新的嵌套张量只显示区间 [start, start+length) 中的元素。由于嵌套表示允许在该维度的每一“行”具有不同的 start 和 length,因此start
和length
也可以是形状为 tensor.shape[0] 的张量。根据您使用的嵌套张量布局,存在一些差异。如果使用 strided 布局,torch.narrow 将把窄化数据复制到具有 strided 布局的连续 NT 中,而 jagged 布局的 narrow() 将创建原始 strided 张量的非连续视图。这种特殊的表示形式对于在 Transformer 模型中表示 kv-caches 非常有用,因为专门的 SDPA 内核可以轻松处理这种格式,从而提高性能。
- 参数
tensor (
torch.Tensor
) – 一个步幅张量,如果使用锯齿布局,它将用作嵌套张量的底层数据,或者如果使用步幅布局,它将被复制。dim (int) – 将应用窄化操作的维度。仅支持 dim=1 用于锯齿布局,而步幅布局支持所有维度
start (Union[int,
torch.Tensor
]) – 窄化操作的起始元素length (Union[int,
torch.Tensor
]) – 窄化操作期间获取的元素数量
- 关键字参数
layout (
torch.layout
, optional) – 返回嵌套张量的期望布局。仅支持 stride 和 jagged 布局。默认值:如果为 None,则为 stride 布局。- 返回类型
示例
>>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64) >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64) >>> narrow_base = torch.randn(5, 10, 20) >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged) >>> nt_narrowed.is_contiguous() False