set_list_to_stack¶
- class tensordict.set_list_to_stack(mode: (Python v3.13))¶
用于控制 TensorDict 中列表处理行为的上下文管理器和装饰器。
启用后,分配给 TensorDict 的列表将自动沿着批次维度堆叠。这对于确保列表中的张量或其他元素在 TensorDict 中被视为可堆叠实体非常有用。
- 当前行为
- 如果未通过此上下文管理器将列表分配给 TensorDict,它将被转换为 numpy 数组
并包装在 NonTensorData 中,如果它无法转换为张量。
- 参数:
mode ((Python v3.13)bool) – 如果为 True,则启用列表到堆叠的转换。如果为 False,则禁用它。
示例
>>> with set_list_to_stack(True): ... td = TensorDict(a=[torch.zeros(()), torch.ones(())], batch_size=2) ... assert (td["a"] == torch.tensor([0, 1])).all() ... assert td[0]["a"] == 0 ... assert td[1]["a"] == 1
另请参阅