set_list_to_stack¶
- class tensordict.set_list_to_stack(mode: bool)¶
上下文管理器和装饰器,用于控制 TensorDict 中列表处理的行为。
启用后,分配给 TensorDict 的列表将自动沿批次维度堆叠。这对于确保列表张量或其他元素在 TensorDict 内被视为可堆叠实体非常有用。
- 当前行为
- 如果在没有此上下文管理器的情况下将列表分配给 TensorDict,它将被转换为 numpy 数组
并在无法转换为 Tensor 时被包装在 NonTensorData 中。
- 未来行为
在 0.10.0 版本中,列表将默认自动堆叠。
- 参数:
mode (bool) – 如果为 True,则启用列表到堆叠的转换。如果为 False,则禁用它。
警告
- 如果列表在未设置此上下文管理器或全局标志的情况下分配给 TensorDict,则会引发 FutureWarning
表明未来行为将发生变化。
示例
>>> with set_list_to_stack(True): ... td = TensorDict(a=[torch.zeros(()), torch.ones(())], batch_size=2) ... assert (td["a"] == torch.tensor([0, 1])).all() ... assert td[0]["a"] == 0 ... assert td[1]["a"] == 1
另请参阅