快捷方式

set_list_to_stack

class tensordict.set_list_to_stack(mode: bool)

上下文管理器和装饰器,用于控制 TensorDict 中列表处理的行为。

启用后,分配给 TensorDict 的列表将自动沿批次维度堆叠。这对于确保列表张量或其他元素在 TensorDict 内被视为可堆叠实体非常有用。

当前行为
如果在没有此上下文管理器的情况下将列表分配给 TensorDict,它将被转换为 numpy 数组

并在无法转换为 Tensor 时被包装在 NonTensorData 中。

未来行为

在 0.10.0 版本中,列表将默认自动堆叠。

参数:

mode (bool) – 如果为 True,则启用列表到堆叠的转换。如果为 False,则禁用它。

警告

如果列表在未设置此上下文管理器或全局标志的情况下分配给 TensorDict,则会引发 FutureWarning

表明未来行为将发生变化。

示例

>>> with set_list_to_stack(True):
...     td = TensorDict(a=[torch.zeros(()), torch.ones(())], batch_size=2)
...     assert (td["a"] == torch.tensor([0, 1])).all()
...     assert td[0]["a"] == 0
...     assert td[1]["a"] == 1

另请参阅

list_to_stack().

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源