set_capture_non_tensor_stack¶
- class tensordict.set_capture_non_tensor_stack(mode: bool)¶
一个上下文管理器或装饰器,用于控制是否将相同的非张量数据堆叠到单个 NonTensorData 对象或 NonTensorStack 中。
- 参数:
mode (bool) – 是否捕获非张量堆栈。如果为
False
,则相同的非张量数据将被堆叠到NonTensorStack
中。如果为True
,则单个NonTensorData
对象将包含唯一值,但具有所需的批次大小。默认为True
。
注意
自 v0.9 起,capture_non_tensor_stack() 默认返回 False。您可以通过以下方式设置
capture_non_tensor_stack()
的值环境变量
CAPTURE_NON_TENSOR_STACK
;通过在脚本开头设置
set_capture_non_tensor_stack(val: bool).set()
;通过将
set_capture_non_tensor_stack(val: bool)
用作上下文管理器或装饰器。
建议使用 set_capture_non_tensor_stack(False) 行为。
示例
>>> with set_capture_non_tensor_stack(False): ... torch.stack([NonTensorData("a"), NonTensorData("a")]) NonTensorData("a", batch_size=[2]) >>> @set_capture_non_tensor_stack(False) ... def my_function(): ... return torch.stack([NonTensorData("a"), NonTensorData("a")]) >>> my_function() NonTensorStack(["a", "a"], stack_dim=0)