快捷方式

dense_stack_tds

class tensordict.dense_stack_tds(td_list: Union[Sequence[TensorDictBase], LazyStackedTensorDict], dim: Optional[int] = None)

将一系列 TensorDictBase 对象(或一个 LazyStackedTensorDict)进行密集堆叠,前提是它们具有相同的结构。

此函数接受一个 TensorDictBase 列表(直接传入或从 LazyStackedTensorDict 中获取)。与调用 `torch.stack(td_list)`(这将返回一个 LazyStackedTensorDict)不同,此函数会展开输入列表的第一个元素,并将输入列表堆叠到该元素上。这仅在输入列表的所有元素都具有相同结构时才有效。TensorDictBase 返回的类型将与输入列表元素的类型相同。

当需要堆叠的 TensorDictBase 对象中有 LazyStackedTensorDict,或者条目(或嵌套条目)中包含 LazyStackedTensorDict 时,此函数非常有用。在这些情况下,调用 `torch.stack(td_list).to_tensordict()` 是不可行的。因此,此函数为密集堆叠提供的列表提供了一种替代方案。

参数:
  • **td_list** (TensorDictBase 列表LazyStackedTensorDict) – 要堆叠的 tds。

  • **dim** (int, 可选) – 要堆叠的维度。如果 td_list 是 LazyStackedTensorDict,则会自动检索。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict import dense_stack_tds
>>> from tensordict.tensordict import assert_allclose_td
>>> td0 = TensorDict({"a": torch.zeros(3)},[])
>>> td1 = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)},[])
>>> td_lazy = torch.stack([td0, td1], dim=0)
>>> td_container = TensorDict({"lazy": td_lazy}, [])
>>> td_container_clone = td_container.clone()
>>> td_stack = torch.stack([td_container, td_container_clone], dim=0)
>>> td_stack
LazyStackedTensorDict(
    fields={
        lazy: LazyStackedTensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, 2]),
            device=None,
            is_shared=False,
            stack_dim=0)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)
>>> td_stack = dense_stack_tds(td_stack) # Automatically use the LazyStackedTensorDict stack_dim
TensorDict(
    fields={
        lazy: LazyStackedTensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
            exclusive_fields={
                1 ->
                    b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 2]),
            device=None,
            is_shared=False,
            stack_dim=1)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
# Note that
# (1) td_stack is now a TensorDict
# (2) this has pushed the stack_dim of "lazy" (0 -> 1)
# (3) this has revealed the exclusive keys.
>>> assert_allclose_td(td_stack, dense_stack_tds([td_container, td_container_clone], dim=0))
# This shows it is the same to pass a list or a LazyStackedTensorDict

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源