快捷方式

set_capture_non_tensor_stack

class tensordict.set_capture_non_tensor_stack(mode: bool)

一个上下文管理器或装饰器,用于控制是否将相同的非张量数据堆叠到单个 NonTensorData 对象或 NonTensorStack 中。

参数:

mode (bool) – 是否捕获非张量堆栈。如果为 False,则相同的非张量数据将被堆叠到 NonTensorStack 中。如果为 True,则单个 NonTensorData 对象将包含唯一值,但具有所需的批次大小。默认为 True

注意

自 v0.9 起,capture_non_tensor_stack() 默认返回 False。您可以通过以下方式设置 capture_non_tensor_stack() 的值

  • 环境变量 CAPTURE_NON_TENSOR_STACK

  • 通过在脚本开头设置 set_capture_non_tensor_stack(val: bool).set()

  • 通过将 set_capture_non_tensor_stack(val: bool) 用作上下文管理器或装饰器。

建议使用 set_capture_non_tensor_stack(False) 行为。

示例

>>> with set_capture_non_tensor_stack(False):
...     torch.stack([NonTensorData("a"), NonTensorData("a")])
NonTensorData("a", batch_size=[2])
>>> @set_capture_non_tensor_stack(False)
... def my_function():
...     return torch.stack([NonTensorData("a"), NonTensorData("a")])
>>> my_function()
NonTensorStack(["a", "a"], stack_dim=0)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源