评价此页

torch.nn.utils.rnn.pack_padded_sequence#

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[源码]#

将包含可变长度填充序列的 Tensor 打包。

input 的形状可以是 T x B x * (如果 batch_firstFalse) 或 B x T x * (如果 batch_firstTrue),其中 T 是最长序列的长度,B 是批次大小,而 * 是任意数量的维度(包括 0)。

对于未排序的序列,请使用 enforce_sorted = False。如果 enforce_sortedTrue,序列应该按长度递减排序,即 input[:,0] 应该是最长序列,input[:,B-1] 应该是最短序列。enforce_sorted = True 仅对 ONNX 导出是必需的。

它是 pad_packed_sequence() 的逆向操作,因此可以使用 pad_packed_sequence() 来恢复 PackedSequence 中打包的底层张量。

注意

此函数接受任何至少有两个维度的输入。您可以将其应用于打包标签,并使用 RNN 的输出来直接计算损失。可以通过访问 PackedSequence 对象的 .data 属性来检索张量。

参数
  • input (Tensor) – 可变长度序列的填充批次。

  • lengths (Tensorlist(int)) – 每个批次元素的序列长度列表(如果以 tensor 形式提供,则必须在 CPU 上)。

  • batch_first (bool, optional) – 如果为 True,则输入格式为 B x T x *;否则为 T x B x *。默认为 False

  • enforce_sorted (bool, optional) – 如果为 True,则输入应包含按长度递减排序的序列。如果为 False,则输入将无条件排序。默认为 True

返回类型

PackedSequence

警告

如果 input 张量的维度大于 length 中对应值,则该维度将被截断。

返回

一个 PackedSequence 对象

返回类型

PackedSequence