快捷方式

MCTSForest

class torchrl.data.MCTSForest(*, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, done_keys: list[NestedKey] | None = None, reward_keys: list[NestedKey] = None, observation_keys: list[NestedKey] = None, action_keys: list[NestedKey] = None, excluded_keys: list[NestedKey] = None, consolidated: bool | None = None)[source]

MCTS 树的集合。

警告

此类目前处于积极开发中。请预计 API 会频繁变动。

此类旨在将 rollouts 存储在存储中,并根据数据集中给定的根节点生成树。

关键字参数:
  • data_map (TensorDictMap, 可选) – 用于存储数据(观测、奖励、状态等)的存储。如果未提供,它将使用 `observation_keys` 和 `action_keys` 列表作为 `in_keys`,通过 from_tensordict_pair() 惰性初始化。

  • node_map (TensorDictMap, 可选) – 从观测空间到索引空间的映射。在内部,node_map 用于收集来自给定节点的所有可能的分支。例如,如果一个观测在 data map 中有两个关联的动作和结果,那么 `node_map` 将返回一个包含 data map 中对应于这两个结果的索引的数据结构。如果未提供,它将使用 `observation_keys` 列表作为 `in_keys` 和 QueryModule 作为 `out_keys`,通过 from_tensordict_pair() 惰性初始化。

  • max_size (int, 可选) – 映射的大小。如果未提供,则默认为 `data_map.max_size`(如果可找到),然后是 `node_map.max_size`。如果均未提供,则默认为 1000。

  • done_keys (NestedKey 列表, 可选) – 环境的完成键。如果未提供,则默认为 `("done", "terminated", "truncated")`。可以使用 get_keys_from_env() 自动确定键。

  • action_keys (NestedKey 列表, 可选) – 环境的动作键。如果未提供,则默认为 `("action",)`。可以使用 get_keys_from_env() 自动确定键。

  • reward_keys (NestedKey 列表, 可选) – 环境的奖励键。如果未提供,则默认为 `("reward",)`。可以使用 get_keys_from_env() 自动确定键。

  • observation_keys (NestedKey 列表, 可选) – 环境的观测键。如果未提供,则默认为 `("observation",)`。可以使用 get_keys_from_env() 自动确定键。

  • excluded_keys (NestedKey 列表, 可选) – 要从数据存储中排除的键列表。

  • consolidated (bool, 可选) – 如果为 `True`,则 `data_map` 存储将在磁盘上进行合并。默认为 `False`。

示例

>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make another final rollout from an intermediate step in the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> #  branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> #  a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
property action_keys: list[tensordict._nestedkey.NestedKey]

动作键。

返回用于从环境输入中检索动作的键。默认动作键为“action”。

返回:

表示动作键的字符串或元组列表。

property done_keys: list[tensordict._nestedkey.NestedKey]

完成键。

返回用于指示剧集已结束的键。默认完成键为“done”、“terminated”和“truncated”。这些键可以在环境的输出中使用以信号传递剧集的结束。

返回:

表示完成键的字符串列表。

extend(rollout, *, return_node: bool = False)[source]

向森林添加一个 rollout。

节点仅在 rollout 分叉处以及 rollout 的端点处添加到树中。

如果没有现有的与 rollout 的初始步骤匹配的树,则会添加一个新树。仅创建一个节点,用于最后一步。

如果存在一个与 rollout 匹配的现有树,则将 rollout 添加到该树中。如果 rollout 在某个步骤与树中的所有其他 rollout 分叉,则在 rollout 分叉的步骤之前创建一个新节点,并为 rollout 的最后一步创建一个叶节点。如果 rollout 的所有步骤都与先前添加的 rollout 匹配,则不会发生任何变化。如果 rollout 匹配到树的叶节点但在此之后继续,则该节点将被扩展到 rollout 的末尾,而不会创建新节点。

参数:
  • rollout (TensorDict) – 要添加到森林的 rollout。

  • return_node (bool, 可选) – 如果为 `True`,则该方法返回添加的节点。默认为 `False`。

返回:

添加到森林的节点。这仅在

当 `return_node` 为 True 时返回。

返回类型:

示例

>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> import torch
>>> forest = MCTSForest()
>>> r0 = TensorDict({
...     'action': torch.tensor([1, 2, 3, 4, 5]),
...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
...     'observation': torch.tensor([  0, 123, 392, 989, 809])
... }, [5])
>>> r1 = TensorDict({
...     'action': torch.tensor([1, 2, 6, 7]),
...     'next': {'observation': torch.tensor([123, 392, 235,  38])},
...     'observation': torch.tensor([  0, 123, 392, 235])
... }, [4])
>>> td_root = r0[0].exclude("next")
>>> forest.extend(r0)
>>> forest.extend(r1)
>>> tree = forest.get_tree(td_root)
>>> print(tree)
Tree(
    count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
    index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
    node_data=TensorDict(
        fields={
            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([]),
        device=cpu,
        is_shared=False),
    node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    rollout=TensorDict(
        fields={
            action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([2]),
        device=cpu,
        is_shared=False),
    subtree=Tree(
        _parent=NonTensorStack(
            [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
            batch_size=torch.Size([2]),
            device=None),
        count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
        hash=NonTensorStack(
            [4341220243998689835, 6745467818783115365],
            batch_size=torch.Size([2]),
            device=None),
        node_data=LazyStackedTensorDict(
            fields={
                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        node_id=NonTensorStack(
            [1, 2],
            batch_size=torch.Size([2]),
            device=None),
        rollout=LazyStackedTensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                next: LazyStackedTensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, -1]),
                    device=cpu,
                    is_shared=False,
                    stack_dim=0),
                observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, -1]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        index=None,
        subtree=None,
        specs=None,
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
    hash=None,
    _parent=None,
    specs=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
get_keys_from_env(env: EnvBase)[source]

根据环境,向 Forest 写入缺失的完成、动作和奖励键。

现有键不会被覆盖。

property observation_keys: list[tensordict._nestedkey.NestedKey]

观测键。

返回用于从环境输出中检索观测的键。默认观测键为“observation”。

返回:

表示观测键的字符串或元组列表。

property reward_keys: list[tensordict._nestedkey.NestedKey]

奖励键。

返回用于从环境输出中检索奖励的键。默认奖励键为“reward”。

返回:

表示奖励键的字符串或元组列表。

to_string(td_root, node_format_fn=<function MCTSForest.<lambda>>)[source]

生成森林中树的字符串表示。

此函数可以提取树中每个节点的信息,因此对于调试很有用。节点逐行列出。每行包含节点的路径,然后是使用 :arg:`node_format_fn` 生成的该节点的字符串表示。每行根据到达相应节点所需的路径长度进行缩进。

参数:
  • td_root (TensorDict) – 树的根节点。

  • node_format_fn (Callable, 可选) – 用于生成树的每个节点字符串的用户定义函数。签名必须为 (Tree) -> Any,并且输出必须可转换为字符串。如果未提供此参数,则生成的字符串是节点的 `Tree.node_data` 属性转换为字典。

示例

>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> forest = MCTSForest()
>>> td_root = TensorDict({"observation": 0,})
>>> rollouts_data = [
...     # [(action, obs), ...]
...     [(3, 123), (1, 456)],
...     [(2, 359), (2, 3094)],
...     [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
...     [(1, 75)],
...     [(3, 123), (0, 948)],
...     [(2, 359), (2, 3094), (10, 68)],
...     [(2, 359), (2, 3094), (11, 9045)],
... ]
>>> for rollout_data in rollouts_data:
...     td = td_root.clone().unsqueeze(0)
...     for action, obs in rollout_data:
...         td = td.update(TensorDict({
...             "action": [action],
...             "next": TensorDict({"observation": [obs]}, [1]),
...         }, [1]))
...         forest.extend(td)
...         td = td["next"].clone()
...
>>> print(forest.to_string(td_root))
(0,) {'observation': tensor(123)}
(0, 0) {'observation': tensor(456)}
(0, 1) {'observation': tensor(847)}
(0, 2) {'observation': tensor(948)}
(1,) {'observation': tensor(3094)}
(1, 0) {'observation': tensor(68)}
(1, 1) {'observation': tensor(9045)}
(2,) {'observation': tensor(75)}

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源