快捷方式

Tree

class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: 'Tree' = None, _parent: 'weakref.ref | list[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[源码]
property branching_action: torch.Tensor | TensorDictBase | None

返回分支到此特定节点的动作。

返回:

如果节点没有父节点,则返回一个张量、tensordict 或 None。

另请参阅

当 rollout 数据包含单个步骤时,这将等于 prev_action

另请参阅

树中给定节点(或观察)相关联的所有动作.

property device: device

检索张量类的设备类型。

dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

将tensordict保存到磁盘。

此函数是 memmap() 的代理。

edges() list[tuple[int, int]][源码]

检索树中的边列表。

每条边都表示为两个节点 ID 的元组:父节点 ID 和子节点 ID。树使用广度优先搜索 (BFS) 进行遍历,以确保访问所有边。

返回:

一个元组列表,其中每个元组包含一个父节点 ID 和一个子节点 ID。

classmethod fields()

返回一个描述此数据类的字段的元组。字段类型为 Field。

接受一个数据类或其实例。元组元素为 Field 类型。

classmethod from_tensordict(tensordict, non_tensordict=None, safe=True)

用于实例化新张量类对象的张量类包装器。

参数:
  • tensordict (TensorDict) – 张量类型的字典

  • non_tensordict (dict) – 包含非张量和嵌套张量类对象的字典

property full_action_spec

树的动作规范。

这是 Tree.specs[‘input_spec’, ‘full_action_spec’] 的别名。

property full_done_spec

树的完成规范。

这是 Tree.specs[‘output_spec’, ‘full_done_spec’] 的别名。

property full_observation_spec

树的观察规范。

这是 Tree.specs[‘output_spec’, ‘full_observation_spec’] 的别名。

property full_reward_spec

树的奖励规范。

这是 Tree.specs[‘output_spec’, ‘full_reward_spec’] 的别名。

property full_state_spec

树的状态规范。

这是 Tree.specs[‘input_spec’, ‘full_state_spec’] 的别名。

fully_expanded(env: EnvBase) bool[源码]

如果子节点数量等于环境基数,则返回 True。

get(key: NestedKey, *args, **kwargs)

获取输入键对应的存储值。

参数:
  • key (str, tuple of str) – 要查询的键。如果是字符串元组,则等同于链式调用 getattr。

  • default – 如果在张量类中找不到键,则返回默认值。

返回:

存储在输入键下的值

get_vertex_by_hash(hash: int) Tree[源码]

遍历树并返回与给定哈希对应的节点。

get_vertex_by_id(id: int) Tree[源码]

遍历树并返回与给定 ID 对应的节点。

property is_terminal: bool | torch.Tensor

如果树没有子节点,则返回 True。

classmethod load(prefix: str | Path, *args, **kwargs) T

从磁盘加载 tensordict。

此类方法是 load_memmap() 的代理。

load_(prefix: str | Path, *args, **kwargs)

在当前 tensordict 中从磁盘加载 tensordict。

此类方法是 load_memmap_() 的代理。

classmethod load_memmap(prefix: str | Path, device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) T

从磁盘加载内存映射的 tensordict。

参数:
  • prefix (strPath to folder) – 应获取已保存 tensordict 的文件夹路径。

  • device (torch.device等效可选) – 如果提供,数据将被异步转换为该设备。支持 “meta” 设备,在这种情况下,数据不会被加载,但会创建一组空的“meta”张量。这对于在不实际打开任何文件的情况下了解模型总大小和结构非常有用。

  • non_blocking (bool, optional) – 如果为 True,则在将张量加载到设备后不会调用同步。默认为 False

  • out (TensorDictBase, optional) – 数据应写入的可选 tensordict。

示例

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

此方法还允许加载嵌套的 tensordicts。

示例

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

tensordict 也可以在“meta”设备上加载,或者作为假张量加载。

示例

>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)

尝试将 state_dict 加载到目标张量类中(原地)。

classmethod make_node(data: TensorDictBase, *, device: torch.device | None = None, batch_size: torch.Size | None = None, specs: Composite | None = None) Tree[源码]

给定一些数据创建一个新节点。

max_length()[源码]

返回树中所有有效路径的最大长度。

路径的长度定义为路径中的节点数。如果树为空,则返回 0。

返回:

树中所有有效路径的最大长度。

返回类型:

int

memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T

将所有张量写入内存映射的 Tensor 中,并放入新的 tensordict。

参数:
  • prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果为 False (默认),如果 tensordict 中的条目已经是存储在磁盘上的张量,并且其文件未按 prefix 正确保存,则会引发异常。如果为 True,则任何现有张量都会被复制到新位置。

关键字参数:
  • num_threads (int, optional) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, optional) – 如果为 Truenum_threads>0,则该方法将返回 tensordict 的 future。

  • share_non_tensor (bool, optional) – 如果为 True,非张量数据将在进程之间共享,并且在同一节点内的任何工作器上的写入操作(例如原地更新或设置)将更新所有其他工作器上的值。如果非张量叶子数量很高(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,则如果同一路径中已存在张量,将引发异常。默认为 True

然后 TensorDict 被锁定,这意味着任何非原地操作都会引发异常(例如,重命名、设置或删除条目)。一旦 TensorDict 被解锁,内存映射属性将变为 False,因为不能保证跨进程的同一性。

返回:

如果 return_early=False,则返回一个带有存储在磁盘上的张量的新 tensordict,否则返回一个 TensorDictFuture 实例。

注意

以这种方式序列化对于深度嵌套的 tensordicts 来说可能很慢,因此不建议在训练循环中调用此方法。

memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T

将所有张量原地写入相应的内存映射张量。

参数:
  • prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果为 False (默认),如果 tensordict 中的条目已经是存储在磁盘上的张量,并且其文件未按 prefix 正确保存,则会引发异常。如果为 True,则任何现有张量都会被复制到新位置。

关键字参数:
  • num_threads (int, optional) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, optional) – 如果为 Truenum_threads>0,则该方法将返回 tensordict 的 future。可以通过 future.result() 查询结果 tensordict。

  • share_non_tensor (bool, optional) – 如果为 True,非张量数据将在进程之间共享,并且在同一节点内的任何工作器上的写入操作(例如原地更新或设置)将更新所有其他工作器上的值。如果非张量叶子数量很高(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,则如果同一路径中已存在张量,将引发异常。默认为 True

然后 TensorDict 被锁定,这意味着任何非原地操作都会引发异常(例如,重命名、设置或删除条目)。一旦 TensorDict 被解锁,内存映射属性将变为 False,因为不能保证跨进程的同一性。

返回:

如果 return_early=False,则返回 self,否则返回 TensorDictFuture 实例。

注意

以这种方式序列化对于深度嵌套的 tensordicts 来说可能很慢,因此不建议在训练循环中调用此方法。

memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

创建一个无内容的内存映射 tensordict,其形状与原始 tensordict 相同。

参数:
  • prefix (str) – 内存映射张量将存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果为 False (默认),如果 tensordict 中的条目已经是存储在磁盘上的张量,并且其文件未按 prefix 正确保存,则会引发异常。如果为 True,则任何现有张量都会被复制到新位置。

关键字参数:
  • num_threads (int, optional) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, optional) – 如果为 Truenum_threads>0,则该方法将返回 tensordict 的 future。

  • share_non_tensor (bool, optional) – 如果为 True,非张量数据将在进程之间共享,并且在同一节点内的任何工作器上的写入操作(例如原地更新或设置)将更新所有其他工作器上的值。如果非张量叶子数量很高(例如,共享大量非张量数据),这可能会导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,则如果同一路径中已存在张量,将引发异常。默认为 True

然后 TensorDict 被锁定,这意味着任何非原地操作都会引发异常(例如,重命名、设置或删除条目)。一旦 TensorDict 被解锁,内存映射属性将变为 False,因为不能保证跨进程的同一性。

返回:

如果 return_early=False,则返回一个具有存储为内存映射张量的新 TensorDict 实例,否则返回一个 TensorDictFuture 实例。

注意

这是将一组大型缓冲区写入磁盘的推荐方法,因为 memmap_() 将会复制信息,这对于大型内容可能很慢。

示例

>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
memmap_refresh_()

如果内存映射的 tensordict 具有 saved_path,则刷新其内容。

如果没有任何路径与之关联,此方法将引发异常。

property node_observation: torch.Tensor | TensorDictBase

返回与此特定节点关联的观察。

这是在分支发生之前定义节点的观察(或观察集合)。如果节点包含 rollout() 属性,则节点观察通常与最后一次执行的动作产生的观察相同,即 node.rollout[..., -1]["next", "observation"]

如果树规范关联了多个观察键,则将返回一个 TensorDict 实例。

为了获得更一致的表示,请参阅 node_observations

property node_observations: torch.Tensor | TensorDictBase

以 TensorDict 格式返回与此特定节点关联的观察。

这是在分支发生之前定义节点的观察(或观察集合)。如果节点包含 rollout() 属性,则节点观察通常与最后一次执行的动作产生的观察相同,即 node.rollout[..., -1]["next", "observation"]

如果树规范关联了多个观察键,则将返回一个 TensorDict 实例。

为了获得更一致的表示,请参阅 node_observations

property num_children: int

此节点的子节点数量。

等于 self.subtree 堆栈中的元素数量。

num_vertices(*, count_repeat: bool = False) int[源码]

返回 Tree 中唯一顶点的数量。

关键字参数:

count_repeat (bool, optional) –

确定是否计算重复的顶点。

  • 如果为 False,则每个唯一顶点只计算一次。

  • 如果为 True,则如果顶点出现在不同的路径中,则会多次计算。

默认为False

返回:

Tree 中唯一顶点的数量。

返回类型:

int

property parent: Tree | None

节点的父节点。

如果节点有父节点且该对象仍在 Python 工作区中,则可以通过此属性返回。

对于重新分支的树,此属性可能会返回一个树堆栈,其中堆栈的每个索引对应一个不同的父节点。

注意

parent 属性在内容上将匹配,但在身份上不匹配:tensorclass 对象将使用相同的张量(即指向相同内存位置的张量)进行重建。

返回:

包含父节点数据的 Tree,或者如果父节点数据超出范围或节点是根节点,则为 None

plot(backend: str = 'plotly', figure: str = 'tree', info: list[str] = None, make_labels: Callable[[Any, ...], Any] | None = None)[源码]

使用指定的后端和图形类型绘制树的可视化。

参数:
  • backend – 要使用的绘图后端。目前仅支持 'plotly'。

  • figure – 要绘制的图形类型。可以是 'tree' 或 'box'。

  • info – 要包含在图形中的其他信息列表(目前未使用)。

  • make_labels – 用于为图形生成自定义标签的可选函数。

抛出:

NotImplementedError – 如果指定了不支持的后端或图形类型。

property prev_action: torch.Tensor | TensorDictBase | None

在此节点观察生成之前的动作。

返回:

如果节点没有父节点,则返回一个张量、tensordict 或 None。

另请参阅

当 rollout 数据包含单个步骤时,这将等于 branching_action

另请参阅

树中给定节点(或观察)相关联的所有动作.

rollout_from_path(path: tuple[int]) TensorDictBase | None[源码]

沿给定路径检索树中的 rollout 数据。

rollout 数据沿每个节点上的最后一个维度 (dim=-1) 连接。如果沿路径未找到 rollout 数据,则返回 None

参数:

path – 代表树中路径的整数元组。

返回:

沿路径连接的 rollout 数据,或在未找到数据时返回 None。

save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

将tensordict保存到磁盘。

此函数是 memmap() 的代理。

property selected_actions: torch.Tensor | TensorDictBase | None

返回一个包含此节点分支出的所有所选动作的张量。

set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)

设置一个新的键值对。

参数:
  • key (str, tuple of str) – 要设置的键的名称。如果是字符串元组,则等同于链式调用 getattr,然后进行最终的 setattr。

  • value (Any) – 要存储在张量类中的值

  • inplace (bool, optional) – 如果为 True,则 set 将尝试就地更新值。如果为 False 或密钥不存在,则值将简单地写入其目标。

返回:

self

state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]

返回一个 state_dict 字典,可用于保存和加载张量类的数据。

to_string(node_format_fn=<function Tree.<lambda>>)[源码]

生成树的字符串表示。

此函数可以提取树中每个节点的信息,因此对于调试很有用。节点逐行列出。每行包含到节点的路径,后跟使用 :arg:`node_format_fn` 生成的该节点的字符串表示。每行根据到达相应节点所需的路径步数进行缩进。

参数:

node_format_fn (Callable, optional) – 用于为树的每个节点生成字符串的用户定义函数。签名必须是 (Tree) -> Any,并且输出必须可转换为字符串。如果未提供此参数,则生成的字符串是节点 Tree.node_data 属性转换为字典。

示例

>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> forest = MCTSForest()
>>> td_root = TensorDict({"observation": 0,})
>>> rollouts_data = [
...     # [(action, obs), ...]
...     [(3, 123), (1, 456)],
...     [(2, 359), (2, 3094)],
...     [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
...     [(1, 75)],
...     [(3, 123), (0, 948)],
...     [(2, 359), (2, 3094), (10, 68)],
...     [(2, 359), (2, 3094), (11, 9045)],
... ]
>>> for rollout_data in rollouts_data:
...     td = td_root.clone().unsqueeze(0)
...     for action, obs in rollout_data:
...         td = td.update(TensorDict({
...             "action": [action],
...             "next": TensorDict({"observation": [obs]}, [1]),
...         }, [1]))
...         forest.extend(td)
...         td = td["next"].clone()
...
>>> tree = forest.get_tree(td_root)
>>> print(tree.to_string())
(0,) {'observation': tensor(123)}
(0, 0) {'observation': tensor(456)}
(0, 1) {'observation': tensor(847)}
(0, 2) {'observation': tensor(948)}
(1,) {'observation': tensor(3094)}
(1, 0) {'observation': tensor(68)}
(1, 1) {'observation': tensor(9045)}
(2,) {'observation': tensor(75)}
to_tensordict(*, retain_none: bool | None = None) TensorDict

将张量类转换为常规 TensorDict。

复制所有条目。内存映射和共享内存张量将被转换为常规张量。

参数:

retain_none (bool) – 如果为 True,则 None 值将被写入 tensordict。否则它们将被丢弃。默认值为 True

返回:

包含与张量类相同值的新的 TensorDict 对象。

unbind(dim: int)

返回沿指定维度解绑的索引张量类实例的元组。

结果张量类实例将共享初始张量类实例的存储。

valid_paths()[源码]

生成树中的所有有效路径。

有效路径是從根節點開始並以葉節點結束的子節點索引序列。每條路徑表示為一個整數元組,其中每個整數對應一個子節點的索引。

产生:

tuple – 树中的有效路径。

vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') dict[int | tuple[int], Tree][源码]

返回一个包含 Tree 顶点的映射。

关键字参数:

key_type (Literal["id", "hash", "path"], optional) –

指定用于顶点的密钥类型。

  • “id”: 使用顶点 ID 作为密钥。

  • “hash”: 使用顶点的哈希作为密钥。

  • “path”: 使用到顶点的路径作为密钥。这可能导致字典的长度比

    当使用 "id""hash" 时,因为同一个节点可能属于多个轨迹。默认为 "hash"

默认为空字符串,可能意味着默认行为。

返回:

一个将密钥映射到 Tree 顶点的字典。

返回类型:

Dict[int | Tuple[int], Tree]

property visits: int | torch.Tensor

返回与此特定节点关联的访问次数。

这是 count 属性的别名。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源