Tree¶
- class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: 'Tree' = None, _parent: 'weakref.ref | list[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[source]¶
- property branching_action: torch.Tensor | TensorDictBase | None¶
返回分支到此特定节点的动作。
- 返回:
一个张量、tensordict 或 None,如果节点没有父节点。
另请参阅
当 rollout 数据包含单个步长时,这将等于
prev_action
。另请参阅
树中与给定节点(或观察)关联的 所有 动作
.
- cat(dim: int = 0, *, out=None)¶
将tensordicts沿给定维度连接成一个tensordict。
此调用等同于调用
torch.cat()
,但与 torch.compile 兼容。
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Self ¶
将tensordict保存到磁盘。
此函数是
memmap()
的代理。
- edges() list[tuple[int, int]] [source]¶
检索树中的边列表。
每条边都表示为一个节点 ID 的元组:父节点 ID 和子节点 ID。树使用广度优先搜索 (BFS) 进行遍历,以确保所有边都被访问。
- 返回:
一个元组列表,其中每个元组包含一个父节点 ID 和一个子节点 ID。
- classmethod fields()¶
返回一个描述此数据类的字段的元组。字段类型为 Field。
接受一个数据类或其实例。元组元素为 Field 类型。
- from_any(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)¶
Recursively converts any object to a TensorDict.
注意
from_any
比常规的 TensorDict 构造函数限制更少。它可以使用自定义启发式方法将数据结构(如 dataclasses 或 tuples)转换为 tensordict。此方法可能会产生一些额外的开销,并在映射策略方面涉及更多主观选择。注意
This method recursively converts the input object to a TensorDict. If the object is already a TensorDict (or any similar tensor collection object), it will be returned as is.
- 参数:
obj – The object to be converted.
- 关键字参数:
auto_batch_size (bool, optional) – 如果
True
,将自动计算 batch size。默认为False
。batch_dims (int, optional) – 如果 auto_batch_size 为
True
,则定义输出 tensordict 应具有的维度数。默认为None
(每个级别的完整 batch size)。device (torch.device, optional) – 将创建 TensorDict 的设备。
batch_size (torch.Size, optional) – TensorDict 的批次大小。与
auto_batch_size
互斥。
- 返回:
A TensorDict representation of the input object.
Supported objects
数据类通过
from_dataclass()
(数据类将被转换为 TensorDict 实例,而不是 tensorclass)。命名元组通过
from_namedtuple()
。通过
from_dict()
的字典。元组通过
from_tuple()
。NumPy 的结构化数组通过
from_struct_array()
。HDF5 对象通过
from_h5()
。
- from_dataclass(*, dest_cls: Type | None = None, auto_batch_size: bool = False, batch_dims: int | None = None, as_tensorclass: bool = False, device: torch.device | None = None, batch_size: torch.Size | None = None)¶
Converts a dataclass into a TensorDict instance.
- 参数:
dataclass – The dataclass instance to be converted.
- 关键字参数:
dest_cls (tensorclass, optional) – 用于映射数据的 tensorclass 类型。如果未提供,则创建一个新类。如果
obj
是一个类型或 as_tensorclass 为 False,则无效。auto_batch_size (bool, optional) – 如果
True
,将自动确定并应用 batch size 到生成的 TensorDict。默认为False
。batch_dims (int, optional) – 如果
auto_batch_size
为True
,则定义输出 tensordict 应具有的维度数。默认为None
(每个级别的完整 batch size)。as_tensorclass (bool, optional) – 如果
True
,则将转换委托给自由函数from_dataclass()
,并返回一个 tensorclass(tensorclass()
)类型或实例,而不是 TensorDict。默认为False
。device (torch.device, optional) – 将创建 TensorDict 的设备。默认为
None
。batch_size (torch.Size, optional) – TensorDict 的批次大小。默认为
None
。
- 返回:
A TensorDict instance derived from the provided dataclass, unless as_tensorclass is True, in which case a tensor-compatible class or instance is returned.
- 抛出:
TypeError – 如果提供的输入不是数据类实例。
警告
This method is distinct from the free function from_dataclass and serves a different purpose. While the free function returns a tensor-compatible class or instance, this method returns a TensorDict instance.
注意
此方法创建一个新的 TensorDict 实例,其键对应于输入 dataclass 的字段。
结果 TensorDict 中的每个键都使用 `cls.from_any` 方法进行初始化。
auto_batch_size
选项允许自动确定批次大小并将其应用于结果 TensorDict。
- from_h5(*, mode: str = 'r', auto_batch_size: bool = False, batch_dims: int | None = None, batch_size: torch.Size | None = None)¶
从 h5 文件创建 PersistentTensorDict。
- 参数:
filename (str) – h5 文件的路径。
- 关键字参数:
mode (str, optional) – 读取模式。默认为
"r"
。auto_batch_size (bool, optional) – 如果
True
,将自动计算 batch size。默认为False
。batch_dims (int, optional) – 如果 auto_batch_size 为
True
,则定义输出 tensordict 应具有的维度数。默认为None
(每个级别的完整 batch size)。batch_size (torch.Size, optional) – TensorDict 的批次大小。默认为
None
。
- 返回:
输入 h5 文件的 PersistentTensorDict 表示。
示例
>>> td = TensorDict.from_h5("path/to/file.h5") >>> print(td) PersistentTensorDict( fields={ key1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), key2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- from_modules(*, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False, expand_identical: bool = False)¶
为 vmap 的 ensemable 学习/特征期望应用检索多个模块的参数。
- 参数:
modules (nn.Module 序列) – 要从中获取参数的模块。如果模块结构不同,则需要懒惰堆叠(请参阅下面的 `lazy_stack` 参数)。
- 关键字参数:
as_module (bool, optional) – 如果
True
,则返回一个TensorDictParams
实例,可用于将参数存储在torch.nn.Module
中。默认为False
。lock (bool, optional) – 如果
True
,则锁定的结果 tensordict。默认为True
。use_state_dict (bool, optional) –
如果为
True
,则将使用模块的 state-dict 并将其解压到具有模型树结构的 TensorDict 中。默认为False
。注意
这在使用 state-dict hook 时尤其有用。
lazy_stack (bool, optional) –
是否密集堆叠或懒惰堆叠参数。默认为
False
(密集堆叠)。注意
lazy_stack
和as_module
是互斥的特性。警告
懒惰输出和非懒惰输出之间有一个重要的区别:非懒惰输出将使用所需的批次大小重新实例化参数,而 `lazy_stack` 将仅将参数表示为懒惰堆叠。这意味着,虽然原始参数可以安全地传递给优化器(当 `lazy_stack=True` 时),但在设置为 `True` 时需要传递新参数。
警告
虽然使用 lazy stack 来保留原始参数引用可能很诱人,但请记住,每次调用
get()
时,lazy stack 都会执行堆栈操作。这将需要内存(参数大小的 N 倍,如果构建了图,则更多)和计算时间。这也意味着优化器将包含更多参数,像step()
或zero_grad()
这样的操作的执行将花费更长的时间。总的来说,lazy_stack
应仅限于极少数用例。expand_identical (bool, optional) – 如果
True
且相同的参数(相同标识)被堆叠到自身,则将返回该参数的扩展版本。在lazy_stack=True
时忽略此参数。
示例
>>> from torch import nn >>> from tensordict import TensorDict >>> torch.manual_seed(0) >>> empty_module = nn.Linear(3, 4, device="meta") >>> n_models = 2 >>> modules = [nn.Linear(3, 4) for _ in range(n_models)] >>> params = TensorDict.from_modules(*modules) >>> print(params) TensorDict( fields={ bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False) >>> # example of batch execution >>> def exec_module(params, x): ... with params.to_module(empty_module): ... return empty_module(x) >>> x = torch.randn(3) >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> # since lazy_stack = False, backprop leaves the original params untouched >>> y.sum().backward() >>> assert params["weight"].grad.norm() > 0 >>> assert modules[0].weight.grad is None
当
lazy_stack=True
时,情况略有不同>>> params = TensorDict.from_modules(*modules, lazy_stack=True) >>> print(params) LazyStackedTensorDict( fields={ bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([2]), device=None, is_shared=False, stack_dim=0) >>> # example of batch execution >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> y.sum().backward() >>> assert modules[0].weight.grad is not None
- from_namedtuple(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)¶
递归地将命名元组转换为 TensorDict。
- 参数:
named_tuple – 要转换的命名元组实例。
- 关键字参数:
auto_batch_size (bool, optional) – 如果
True
,将自动计算 batch size。默认为False
。batch_dims (int, optional) – 如果
auto_batch_size
为True
,则定义输出 tensordict 应具有的维度数。默认为None
(每个级别的完整 batch size)。device (torch.device, optional) – 将创建 TensorDict 的设备。默认为
None
。batch_size (torch.Size, optional) – TensorDict 的批次大小。默认为
None
。
- 返回:
输入命名元组的 TensorDict 表示。
示例
>>> from tensordict import TensorDict >>> import torch >>> data = TensorDict({ ... "a_tensor": torch.zeros((3)), ... "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3]) >>> nt = data.to_namedtuple() >>> print(nt) GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!')) >>> TensorDict.from_namedtuple(nt, auto_batch_size=True) TensorDict( fields={ a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None), a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- from_pytree(*, batch_size: torch.Size | None = None, auto_batch_size: bool = False, batch_dims: int | None = None)¶
将 pytree 转换为 TensorDict 实例。
此方法旨在尽可能保留 pytree 的嵌套结构。
其他非张量键将被添加,以跟踪每个级别的标识,从而提供内置的 pytree 到 tensordict 的双射转换 API。
当前接受的类包括列表、元组、命名元组和字典。
注意
对于字典,非 NestedKey 键会作为
NonTensorData
实例单独注册。注意
可转换为张量类型(如 int、float 或 np.ndarray)将被转换为 torch.Tensor 实例。请注意,此转换是满射的:将 tensordict 转换回 pytree 将无法恢复原始类型。
示例
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key >>> class WeirdLookingClass: ... pass ... >>> weird_key = WeirdLookingClass() >>> # Make a pytree with tuple, lists, dict and namedtuple >>> pytree = ( ... [torch.randint(10, (3,)), torch.zeros(2)], ... { ... "tensor": torch.randn( ... 2, ... ), ... "td": TensorDict({"one": 1}), ... weird_key: torch.randint(10, (2,)), ... "list": [1, 2, 3], ... }, ... {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, ... ) >>> # Build a TensorDict from that pytree >>> td = TensorDict.from_pytree(pytree) >>> # Recover the pytree >>> pytree_recon = td.to_pytree() >>> # Check that the leaves match >>> def check(v1, v2): >>> assert (v1 == v2).all() >>> >>> torch.utils._pytree.tree_map(check, pytree, pytree_recon) >>> assert weird_key in pytree_recon[1]
- from_remote_init(group: 'ProcessGroup' | None = None, device: torch.device | None = None) Self ¶
从远程发送的元数据创建新的 tensordict 实例。
此类方法接收由 init_remote 发送的元数据,创建具有匹配形状和 dtype 的新 tensordict,然后异步接收实际的 tensordict 内容。
- 参数:
src (int) – 发送元数据的源进程的 rank。
group ("ProcessGroup", optional) – 要使用的进程组。默认为 None。
device (torch.device, 可选) – 用于张量运算的设备。默认为 None。
- 返回:
使用接收到的元数据和内容初始化的新 tensordict 实例。
- 返回类型:
TensorDict
另请参阅
发送进程应已调用 ~.init_remote 来发送元数据和内容。
- from_struct_array(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None) Self ¶
将结构化 numpy 数组转换为 TensorDict。
生成的 TensorDict 将与 numpy 数组共享相同的内存内容(这是一次零拷贝操作)。原地更改结构化 numpy 数组的值会影响 TensorDict 的内容。
注意
此方法执行零拷贝操作,这意味着生成的 TensorDict 将与输入的 numpy 数组共享相同的内存内容。因此,原地更改 numpy 数组的值会影响 TensorDict 的内容。
- 参数:
struct_array (np.ndarray) – 要转换的结构化 numpy 数组。
- 关键字参数:
auto_batch_size (bool, optional) – 如果
True
,将自动计算 batch size。默认为False
。batch_dims (int, optional) – 如果
auto_batch_size
为True
,则定义输出 tensordict 应具有的维度数。默认为None
(每个级别的完整 batch size)。device (torch.device, 可选) –
将创建 TensorDict 的设备。默认为
None
。注意
更改设备(即,指定任何非 `None` 或 `"cpu"` 的设备)将传输数据,从而导致返回数据的内存位置发生更改。
batch_size (torch.Size, 可选) – TensorDict 的批次大小。默认为 None。
- 返回:
输入的结构化 numpy 数组的 TensorDict 表示。
示例
>>> x = np.array( ... [("Rex", 9, 81.0), ("Fido", 3, 27.0)], ... dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")], ... ) >>> td = TensorDict.from_struct_array(x) >>> x_recon = td.to_struct_array() >>> assert (x_recon == x).all() >>> assert x_recon.shape == x.shape >>> # Try modifying x age field and check effect on td >>> x["age"] += 1 >>> assert (td["age"] == np.array([10, 4])).all()
- classmethod from_tensordict(tensordict: TensorDictBase, non_tensordict: dict | None = None, safe: bool = True) Self ¶
用于实例化新张量类对象的张量类包装器。
- 参数:
tensordict (TensorDictBase) – 张量类型字典
non_tensordict (dict) – 包含非张量和嵌套张量类对象的字典
safe (bool) – 如果 tensordict 不是 TensorDictBase 实例,是否引发错误
- from_tuple(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)¶
将元组转换为 TensorDict。
- 参数:
obj – 要转换的元组实例。
- 关键字参数:
auto_batch_size (bool, optional) – 如果
True
,将自动计算 batch size。默认为False
。batch_dims (int, optional) – 如果 auto_batch_size 为
True
,则定义输出 tensordict 应具有的维度数。默认为None
(每个级别的完整 batch size)。device (torch.device, optional) – 将创建 TensorDict 的设备。默认为
None
。batch_size (torch.Size, optional) – TensorDict 的批次大小。默认为
None
。
- 返回:
输入的元组的 TensorDict 表示。
示例
>>> my_tuple = (1, 2, 3) >>> td = TensorDict.from_tuple(my_tuple) >>> print(td) TensorDict( fields={ 0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), 1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), 2: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- fromkeys(value: Any = 0)¶
从键列表和单个值创建 tensordict。
- 参数:
keys (list of NestedKey) – 指定新字典键的可迭代对象。
value (compatible type, optional) – 所有键的值。默认为
0
。
- property full_action_spec¶
树的动作规范。
这是 Tree.specs[‘input_spec’, ‘full_action_spec’] 的别名。
- property full_done_spec¶
树的完成规范。
这是 Tree.specs[‘output_spec’, ‘full_done_spec’] 的别名。
- property full_observation_spec¶
树的观察规范。
这是 Tree.specs[‘output_spec’, ‘full_observation_spec’] 的别名。
- property full_reward_spec¶
树的奖励规范。
这是 Tree.specs[‘output_spec’, ‘full_reward_spec’] 的别名。
- property full_state_spec¶
树的状态规范。
这是 Tree.specs[‘input_spec’, ‘full_state_spec’] 的别名。
- get(key: NestedKey, *args, **kwargs)¶
获取输入键对应的存储值。
- 参数:
key (str, tuple of str) – 要查询的键。如果是字符串元组,则等同于链式调用 getattr。
default – 如果在张量类中找不到键,则返回默认值。
- 返回:
存储在输入键下的值
- property is_terminal: bool | torch.Tensor¶
如果没有子节点,则返回 True。
- lazy_stack(dim: int = 0, *, out=None, **kwargs)¶
创建 TensorDicts 的懒惰堆叠。
有关详细信息,请参阅
lazy_stack()
。
- load(*args, **kwargs) Self ¶
从磁盘加载 tensordict。
此类方法是
load_memmap()
的代理。
- load_(prefix: str | Path, *args, **kwargs)¶
在当前 tensordict 中从磁盘加载 tensordict。
此类方法是
load_memmap_()
的代理。
- load_memmap(device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) Self ¶
从磁盘加载内存映射的 tensordict。
- 参数:
prefix (str 或 Path to folder) – 应获取已保存 tensordict 的文件夹路径。
device (torch.device 或 等效项, 可选) – 如果提供,数据将异步转换为该设备。支持 `"meta"` 设备,在这种情况下,数据不会被加载,而是创建一组空的 "meta" 张量。这对于在不实际打开任何文件的情况下了解模型大小和结构很有用。
non_blocking (bool, optional) – 如果为
True
,则在将张量加载到设备后不会调用同步。默认为False
。out (TensorDictBase, optional) – 数据应写入的可选 tensordict。
示例
>>> from tensordict import TensorDict >>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0) >>> td.memmap("./saved_td") >>> td_load = TensorDict.load_memmap("./saved_td") >>> assert (td == td_load).all()
此方法还允许加载嵌套的 tensordicts。
示例
>>> nested = TensorDict.load_memmap("./saved_td/nested") >>> assert nested["e"] == 0
tensordict 也可以在“meta”设备上加载,或者作为假张量加载。
示例
>>> import tempfile >>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}}) >>> with tempfile.TemporaryDirectory() as path: ... td.save(path) ... td_load = TensorDict.load_memmap(path, device="meta") ... print("meta:", td_load) ... from torch._subclasses import FakeTensorMode ... with FakeTensorMode(): ... td_load = TensorDict.load_memmap(path) ... print("fake:", td_load) meta: TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False) fake: TensorDict( fields={ a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)¶
尝试将 state_dict 加载到目标张量类中(原地)。
- classmethod make_node(data: TensorDictBase, *, device: torch.device | None = None, batch_size: torch.Size | None = None, specs: Composite | None = None) Tree [source]¶
给定一些数据创建一个新节点。
- maybe_dense_stack(dim: int = 0, *, out=None, **kwargs)¶
尝试使 TensorDicts 密集堆叠,并在需要时回退到懒惰堆叠。
有关详细信息,请参阅
maybe_dense_stack()
。
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) Self ¶
将所有张量写入内存映射的 Tensor 中,并放入新的 tensordict。
- 参数:
prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。
copy_existing (bool) – 如果为 False(默认值),则如果 tensordict 中的某个条目已经是存储在磁盘上的张量并具有关联文件,但未按 prefix 指定的位置保存,则会引发异常。如果为
True
,则任何现有张量都将被复制到新位置。
- 关键字参数:
num_threads (int, optional) – 用于写入 memmap 张量的线程数。默认为 0。
return_early (bool, optional) – 如果
True
且num_threads>0
,则该方法将返回 tensordict 的未来对象。share_non_tensor (bool, optional) – 如果为
True
,则非张量数据将在进程之间共享,并且单个节点内任何工作进程上的写入操作(例如原地更新或设置)将更新所有其他工作进程上的值。如果非张量叶子节点数量很高(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为False
。existsok (bool, optional) – 如果为
False
,则如果同一路径下已存在张量,将引发异常。默认为True
。
然后,Tensordict 被锁定,这意味着任何非就地写入操作(例如重命名、设置或删除条目)都将引发异常。一旦 tensordict 被解锁,内存映射属性将变为
False
,因为不能保证跨进程身份。- 返回:
返回一个新的 tensordict,其中张量存储在磁盘上(如果
return_early=False
),否则返回一个TensorDictFuture
实例。
注意
以这种方式序列化对于深度嵌套的 tensordicts 来说可能很慢,因此不建议在训练循环中调用此方法。
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) Self ¶
将所有张量原地写入相应的内存映射张量。
- 参数:
prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。
copy_existing (bool) – 如果为 False(默认值),则如果 tensordict 中的某个条目已经是存储在磁盘上的张量并具有关联文件,但未按 prefix 指定的位置保存,则会引发异常。如果为
True
,则任何现有张量都将被复制到新位置。
- 关键字参数:
num_threads (int, optional) – 用于写入 memmap 张量的线程数。默认为 0。
return_early (bool, optional) – 如果
True
且num_threads>0
,则该方法将返回 tensordict 的未来对象。可以使用 future.result() 查询结果 tensordict。share_non_tensor (bool, optional) – 如果为
True
,则非张量数据将在进程之间共享,并且单个节点内任何工作进程上的写入操作(例如原地更新或设置)将更新所有其他工作进程上的值。如果非张量叶子节点数量很高(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为False
。existsok (bool, optional) – 如果为
False
,则如果同一路径下已存在张量,将引发异常。默认为True
。
然后,Tensordict 被锁定,这意味着任何非就地写入操作(例如重命名、设置或删除条目)都将引发异常。一旦 tensordict 被解锁,内存映射属性将变为
False
,因为不能保证跨进程身份。- 返回:
如果
return_early=False
,则返回 self,否则返回TensorDictFuture
实例。
注意
以这种方式序列化对于深度嵌套的 tensordicts 来说可能很慢,因此不建议在训练循环中调用此方法。
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Self ¶
创建一个无内容的内存映射 tensordict,其形状与原始 tensordict 相同。
- 参数:
prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。
copy_existing (bool) – 如果为 False(默认值),则如果 tensordict 中的某个条目已经是存储在磁盘上的张量并具有关联文件,但未按 prefix 指定的位置保存,则会引发异常。如果为
True
,则任何现有张量都将被复制到新位置。
- 关键字参数:
num_threads (int, optional) – 用于写入 memmap 张量的线程数。默认为 0。
return_early (bool, optional) – 如果
True
且num_threads>0
,则该方法将返回 tensordict 的未来对象。share_non_tensor (bool, optional) – 如果为
True
,则非张量数据将在进程之间共享,并且单个节点内任何工作进程上的写入操作(例如原地更新或设置)将更新所有其他工作进程上的值。如果非张量叶子节点数量很高(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为False
。existsok (bool, optional) – 如果为
False
,则如果同一路径下已存在张量,将引发异常。默认为True
。
然后,Tensordict 被锁定,这意味着任何非就地写入操作(例如重命名、设置或删除条目)都将引发异常。一旦 tensordict 被解锁,内存映射属性将变为
False
,因为不能保证跨进程身份。- 返回:
如果
return_early=False
,则创建一个新的TensorDict
实例,其中数据存储为内存映射张量;否则,创建一个TensorDictFuture
实例。
注意
这是将一组大型缓冲区写入磁盘的推荐方法,因为
memmap_()
将复制信息,这对于大型内容来说可能很慢。示例
>>> td = TensorDict({ ... "a": torch.zeros((3, 64, 64), dtype=torch.uint8), ... "b": torch.zeros(1, dtype=torch.int64), ... }, batch_size=[]).expand(1_000_000) # expand does not allocate new memory >>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()¶
如果内存映射的 tensordict 具有
saved_path
,则刷新其内容。如果没有任何路径与之关联,此方法将引发异常。
- property node_observation: torch.Tensor | TensorDictBase¶
返回与此特定节点关联的观察。
这是在出现分支之前定义节点的观察(或观察集)。如果节点包含
rollout()
属性,则节点观察通常与最后一次操作产生的观察相同,即node.rollout[..., -1]["next", "observation"]
。如果树规范关联了多个观察键,则会返回
TensorDict
实例。为了获得更一致的表示,请参见
node_observations
。
- property node_observations: torch.Tensor | TensorDictBase¶
以 TensorDict 格式返回与此特定节点关联的观察。
这是在出现分支之前定义节点的观察(或观察集)。如果节点包含
rollout()
属性,则节点观察通常与最后一次操作产生的观察相同,即node.rollout[..., -1]["next", "observation"]
。如果树规范关联了多个观察键,则会返回
TensorDict
实例。为了获得更一致的表示,请参见
node_observations
。
- property num_children: int¶
此节点的子节点数量。
等于
self.subtree
堆栈中元素的数量。
- num_vertices(*, count_repeat: bool = False) int [source]¶
返回 Tree 中唯一顶点的数量。
- 关键字参数:
count_repeat (bool, optional) –
确定是否计算重复的顶点。
如果为
False
,则每个唯一顶点只计算一次。如果为
True
,则如果顶点出现在不同的路径中,则多次计算它们。
默认为
False
。- 返回:
Tree 中唯一顶点的数量。
- 返回类型:
int
- property parent: Tree | None¶
节点的父节点。
如果节点有父节点且该对象仍在 python 工作区中,则此属性将返回该父节点。
对于重新分支的树,此属性可能会返回一个树堆栈,其中堆栈的每个索引对应于一个不同的父节点。
注意
parent
属性在内容上匹配,但在身份上不匹配:tensorclass 对象使用相同的张量(即指向相同内存位置的张量)进行重建。- 返回:
包含父节点数据的
Tree
,或者如果父节点数据超出范围或节点是根节点,则为None
。
- plot(backend: str = 'plotly', figure: str = 'tree', info: list[str] = None, make_labels: Callable[[Any, ...], Any] | None = None)[source]¶
使用指定的后端和图形类型绘制树的可视化。
- 参数:
backend – 要使用的绘图后端。目前仅支持 ‘plotly’。
figure – 要绘制的图形类型。可以是 ‘tree’ 或 ‘box’。
info – 要包含在图形中的其他信息列表(目前未使用)。
make_labels – 用于为图形生成自定义标签的可选函数。
- 抛出:
NotImplementedError – 如果指定了不受支持的后端或图形类型。
- property prev_action: torch.Tensor | TensorDictBase | None¶
在此节点观察生成之前的动作。
- 返回:
一个张量、tensordict 或 None,如果节点没有父节点。
另请参阅
当 rollout 数据包含单个步长时,这将等于
branching_action
。另请参阅
树中与给定节点(或观察)关联的 所有 动作
.
- rollout_from_path(path: tuple[int]) TensorDictBase | None [source]¶
沿给定路径检索树中的 rollout 数据。
对于路径中的每个节点,rollout 数据沿最后一个维度 (dim=-1) 连接。如果沿路径未找到 rollout 数据,则返回
None
。- 参数:
path – 表示树中路径的整数元组。
- 返回:
沿路径连接的 rollout 数据,或在未找到数据时返回 None。
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Self ¶
将tensordict保存到磁盘。
此函数是
memmap()
的代理。
- property selected_actions: torch.Tensor | TensorDictBase | None¶
返回从该节点分支出的所有选定动作的张量。
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)¶
设置一个新的键值对。
- 参数:
key (str, tuple of str) – 要设置的键的名称。如果是字符串元组,则等同于链式调用 getattr,然后进行最终的 setattr。
value (Any) – 要存储在张量类中的值
inplace (bool, optional) – 如果为
True
,则 set 将尝试原地更新值。如果为False
或键不存在,则值将简单地写入其目的地。
- 返回:
self
- stack(dim: int = 0, *, out=None)¶
沿给定维度将 tensordicts 堆叠成一个单一的 tensordict。
此调用等效于调用
torch.stack()
,但与 torch.compile 兼容。
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any] ¶
返回一个 state_dict 字典,可用于保存和加载张量类的数据。
- to_string(node_format_fn=<function Tree.<lambda>>)[source]¶
生成树的字符串表示。
此函数可以提取树中每个节点的信息,因此对调试很有用。节点逐行列出。每行包含到节点的路径,然后是使用 :arg:`node_format_fn` 生成的该节点的字符串表示。每行根据到达相应节点所需的路径步数进行缩进。
- 参数:
node_format_fn (Callable, optional) – 用户定义的函数,用于为树的每个节点生成字符串。签名必须为
(Tree) -> Any
,并且输出必须可转换为字符串。如果未提供此参数,则生成的字符串是节点Tree.node_data
属性转换为字典。
示例
>>> from torchrl.data import MCTSForest >>> from tensordict import TensorDict >>> forest = MCTSForest() >>> td_root = TensorDict({"observation": 0,}) >>> rollouts_data = [ ... # [(action, obs), ...] ... [(3, 123), (1, 456)], ... [(2, 359), (2, 3094)], ... [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)], ... [(1, 75)], ... [(3, 123), (0, 948)], ... [(2, 359), (2, 3094), (10, 68)], ... [(2, 359), (2, 3094), (11, 9045)], ... ] >>> for rollout_data in rollouts_data: ... td = td_root.clone().unsqueeze(0) ... for action, obs in rollout_data: ... td = td.update(TensorDict({ ... "action": [action], ... "next": TensorDict({"observation": [obs]}, [1]), ... }, [1])) ... forest.extend(td) ... td = td["next"].clone() ... >>> tree = forest.get_tree(td_root) >>> print(tree.to_string()) (0,) {'observation': tensor(123)} (0, 0) {'observation': tensor(456)} (0, 1) {'observation': tensor(847)} (0, 2) {'observation': tensor(948)} (1,) {'observation': tensor(3094)} (1, 0) {'observation': tensor(68)} (1, 1) {'observation': tensor(9045)} (2,) {'observation': tensor(75)}
- to_tensordict(*, retain_none: bool | None = None) TensorDict ¶
将张量类转换为常规 TensorDict。
复制所有条目。内存映射和共享内存张量将被转换为常规张量。
- 参数:
retain_none (bool) – 如果为
True
,则None
值将被写入 tensordict。否则它们将被丢弃。默认值:True
。- 返回:
包含与张量类相同值的新的 TensorDict 对象。
- unbind(dim: int)¶
返回沿指定维度解绑的索引张量类实例的元组。
结果张量类实例将共享初始张量类实例的存储。
- valid_paths()[source]¶
生成树中的所有有效路径。
有效路径是从根节点开始并以叶节点结束的子索引序列。每条路径表示为一个整数元组,其中每个整数对应于一个子节点的索引。
- 产生:
tuple – 树中的有效路径。
- vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') dict[int | tuple[int], Tree] [source]¶
返回包含 Tree 顶点的映射。
- 关键字参数:
key_type (Literal["id", "hash", "path"], optional) –
指定用于顶点的键的类型。
”id”: 使用顶点 ID 作为键。
”hash”: 使用顶点的哈希作为键。
- ”path”: 使用顶点的路径作为键。这可能导致字典的长度比
当使用
"id"
或"hash"
时,因为相同的节点可能属于多个轨迹。默认为"hash"
。
默认为空字符串,这可能意味着默认行为。
- 返回:
将键映射到 Tree 顶点的字典。
- 返回类型:
Dict[int | Tuple[int], Tree]
- property visits: int | torch.Tensor¶
返回与此特定节点关联的访问次数。
这是
count
属性的别名。