评价此页

torch.nn.utils.rnn.pack_padded_sequence#

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[源代码]#

打包包含可变长度填充序列的 Tensor。

如果 batch_firstFalse,则 input 的大小可以是 T x B x *;如果 batch_firstTrue,则可以是 B x T x *,其中 T 是最长序列的长度,B 是批次大小,* 是任意数量的维度(包括 0)。

对于未排序的序列,请使用 enforce_sorted = False。如果 enforce_sortedTrue,则序列应按长度降序排列,即 input[:,0] 应该是最长的序列,而 input[:,B-1] 是最短的序列。enforce_sorted = True 仅在 ONNX 导出时才需要。

它是 pad_packed_sequence() 的逆操作,因此 pad_packed_sequence() 可用于恢复打包在 PackedSequence 中的基础张量。

注意

此函数接受至少具有两个维度的任何输入。您可以将其应用于打包标签,并使用 RNN 的输出与它们直接计算损失。可以通过访问 PackedSequence 对象的 .data 属性来从中检索张量。

参数
  • input (Tensor) – 变长序列的填充批次。

  • lengths (Tensorlist(int)) – 每个批次元素的序列长度列表(如果作为张量提供,则必须在 CPU 上)。

  • batch_first (bool, 可选) – 如果为 True,则输入应为 B x T x * 格式,否则为 T x B x *。默认值:False

  • enforce_sorted (bool, 可选) – 如果为 True,则输入应包含按长度降序排列的序列。如果为 False,则输入将无条件排序。默认值:True

返回类型

PackedSequence

警告

如果 input 张量的长度大于 length 中对应的值,则其维度将被截断。

返回

一个 PackedSequence 对象

返回类型

PackedSequence