快捷方式

pad_sequence

class tensordict.pad_sequence(list_of_tensordicts: Sequence[T], pad_dim: int = 0, padding_value: float = 0.0, out: Optional[T] = None, device: Optional[Union[device, str, int]] = None, return_mask: bool | tensordict._nestedkey.NestedKey = False)

对 tensordict 列表进行填充,以便将它们堆叠成连续格式。

参数:
  • list_of_tensordicts (List[TensorDictBase]) – 需要填充和堆叠的实例列表。

  • pad_dim (int, optional) – pad_dim 表示要填充 tensordict 中所有键的维度。默认为 0

  • padding_value (number, optional) – 填充值。默认为 0.0

  • out (TensorDictBase, optional) – 如果提供,则为写入数据的目标。

  • return_mask (boolNestedKey, optional) – 如果为 True,则会返回一个“masks”条目。如果 return_mask 是一个嵌套键(字符串或字符串元组),它将返回掩码并用作掩码条目的键。它包含一个与堆叠的 tensordict 结构相同的 tensordict,其中每个条目都包含有效值的掩码,大小为 torch.Size([stack_len, *new_shape]),其中 new_shape[pad_dim] = max_seq_length,其余 new_shape 与包含张量的先前形状匹配。

示例

>>> list_td = [
...     TensorDict({"a": torch.zeros((3, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
...     TensorDict({"a": torch.zeros((5, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
...     ]
>>> padded_td = pad_sequence(list_td, return_mask=True)
>>> print(padded_td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 5, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 6, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        masks: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 5]), device=cpu, dtype=torch.bool, is_shared=False),
                b: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源