MCTSForest¶
- class torchrl.data.MCTSForest(*, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, done_keys: list[NestedKey] | None = None, reward_keys: list[NestedKey] = None, observation_keys: list[NestedKey] = None, action_keys: list[NestedKey] = None, excluded_keys: list[NestedKey] = None, consolidated: bool | None = None)[source]¶
A collection of MCTS trees.
Warning
This class is currently under active development. Expect significant changes to the API.
This class is aimed at storing rollouts in a storage, and producing trees based on a given root from that dataset.
- Keyword Arguments:
data_map (TensorDictMap, optional) – the storage used for the data (observations, rewards, states etc). If not provided, it is lazily initialized through from_tensordict_pair(), using the observation_keys and action_keys as in_keys.
node_map (TensorDictMap, optional) – a map from the observation space to the index space. Internally, the node map is used to gather all possible branches coming out of a given node. For example, if an observation has two associated actions and outcomes in the data map, then node_map will return a data structure containing the two indices in data_map that correspond to these two outcomes. If not provided, it is lazily initialized through from_tensordict_pair(), using the observation_keys list as in_keys and the QueryModule as out_keys.
max_size (int, optional) – the size of the maps. If not provided, defaults to data_map.max_size if this can be found, then to node_map.max_size. If none of these is provided, defaults to 1000.
done_keys (list of NestedKey, optional) – the done keys of the environment. If not provided, defaults to ("done", "terminated", "truncated"). get_keys_from_env() can be used to automatically determine the keys.
action_keys (list of NestedKey, optional) – the action keys of the environment. If not provided, defaults to ("action",). get_keys_from_env() can be used to automatically determine the keys.
reward_keys (list of NestedKey, optional) – the reward keys of the environment. If not provided, defaults to ("reward",). get_keys_from_env() can be used to automatically determine the keys.
observation_keys (list of NestedKey, optional) – the observation keys of the environment. If not provided, defaults to ("observation",). get_keys_from_env() can be used to automatically determine the keys.
excluded_keys (list of NestedKey, optional) – a list of keys to exclude from the data storage.
consolidated (bool, optional) – if True, the data_map storage is consolidated on disk. Defaults to False.
Examples
>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make another final rollout from an intermediate step in the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> # branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> # a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
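The compact/non-compact distinction used by get_tree above can be illustrated independently of torchrl. The following is a minimal, hypothetical sketch (the function and the dict-based tree are illustrative, not part of the MCTSForest API) of collapsing single-child chains, which is what the default compact=True does to rollout segments:

```python
# Hypothetical sketch: collapse nodes that have exactly one child, as
# `get_tree` does by default (compact=True). Not the torchrl implementation.

def compact(tree: dict) -> dict:
    """Merge every single-child chain into one edge of the parent."""
    out = {}
    for key, child in tree.items():
        # Follow the chain while the current child has exactly one child itself.
        while len(child) == 1:
            next_key, next_child = next(iter(child.items()))
            key = key + next_key  # merge the edge labels (rollout segments)
            child = next_child
        out[key] = compact(child)
    return out

# A linear prefix (a, b) that forks into (c,) and (d, e):
tree = {("a",): {("b",): {("c",): {}, ("d",): {("e",): {}}}}}
print(compact(tree))
# → {('a', 'b'): {('c',): {}, ('d', 'e'): {}}}
```

The fork at ("a", "b") survives compaction, while both linear runs are merged into single edges, mirroring how the compact tree above has max_length 2 while the non-compact one has max_length 9.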
- property action_keys: list[tensordict._nestedkey.NestedKey]¶
The action keys.
Returns the keys used to retrieve actions from the environment's input. The default action key is "action".
- Returns:
A list of strings or tuples representing the action keys.
- property done_keys: list[tensordict._nestedkey.NestedKey]¶
The done keys.
Returns the keys used to indicate that an episode has ended. The default done keys are "done", "terminated" and "truncated". These keys can be used in the environment's output to signal the end of an episode.
- Returns:
A list of strings representing the done keys.
- extend(rollout, *, return_node: bool = False)[source]¶
Add a rollout to the forest.
Nodes are only added to a tree at the points where rollouts fork from each other and at the endpoints of rollouts.
If there is no existing tree that matches the first steps of the rollout, a new tree is added. Only one node is created, for the final step.
If there is an existing tree that matches the rollout, the rollout is added to that tree. If the rollout diverges from all other rollouts in the tree at some step, a new node is created before the step where the rollouts diverge, and a leaf node is created for the final step of the rollout. If all of the rollout's steps match a previously added rollout, nothing changes. If the rollout matches up to a leaf node of a tree but continues beyond it, that node is extended to the end of the rollout, and no new nodes are created.
- Parameters:
rollout (TensorDict) – the rollout to add to the forest.
return_node (bool, optional) – if True, the method returns the added node. Defaults to False.
- Returns:
The node that was added to the forest. This is only returned if return_node is True.
- Return type:
Examples
>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> import torch
>>> forest = MCTSForest()
>>> r0 = TensorDict({
...     'action': torch.tensor([1, 2, 3, 4, 5]),
...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
...     'observation': torch.tensor([  0, 123, 392, 989, 809])
... }, [5])
>>> r1 = TensorDict({
...     'action': torch.tensor([1, 2, 6, 7]),
...     'next': {'observation': torch.tensor([123, 392, 235, 38])},
...     'observation': torch.tensor([  0, 123, 392, 235])
... }, [4])
>>> td_root = r0[0].exclude("next")
>>> forest.extend(r0)
>>> forest.extend(r1)
>>> tree = forest.get_tree(td_root)
>>> print(tree)
Tree(
    count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
    index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
    node_data=TensorDict(
        fields={
            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([]),
        device=cpu,
        is_shared=False),
    node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    rollout=TensorDict(
        fields={
            action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([2]),
        device=cpu,
        is_shared=False),
    subtree=Tree(
        _parent=NonTensorStack(
            [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
            batch_size=torch.Size([2]),
            device=None),
        count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
        hash=NonTensorStack(
            [4341220243998689835, 6745467818783115365],
            batch_size=torch.Size([2]),
            device=None),
        node_data=LazyStackedTensorDict(
            fields={
                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        node_id=NonTensorStack(
            [1, 2],
            batch_size=torch.Size([2]),
            device=None),
        rollout=LazyStackedTensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                next: LazyStackedTensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, -1]),
                    device=cpu,
                    is_shared=False,
                    stack_dim=0),
                observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, -1]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        index=None,
        subtree=None,
        specs=None,
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
    hash=None,
    _parent=None,
    specs=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
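The forking rule described for extend above (a node is created only where a new rollout diverges from a stored one) can be sketched without torchrl. The helper below is hypothetical, not part of the API; it only finds the step index at which two rollouts split, which is where extend would insert a branch node:

```python
# Hypothetical sketch of the forking rule used by `extend`: a new rollout
# shares a prefix with a stored one and splits at the first mismatch.
# Not part of the torchrl API.
from typing import Optional


def fork_step(stored: list, new: list) -> Optional[int]:
    """Return the index of the first step where the two rollouts diverge,
    or None if one is a prefix of the other (no new branch is needed)."""
    for i, (a, b) in enumerate(zip(stored, new)):
        if a != b:
            return i
    return None  # pure prefix/extension: no fork, possibly a leaf extension


r0 = [(1, 123), (2, 392), (3, 989)]  # (action, next_observation) pairs
r1 = [(1, 123), (2, 392), (6, 235)]  # same first two steps, then diverges
print(fork_step(r0, r1))  # → 2: a node is created before step 2, plus a leaf
print(fork_step(r0, r0))  # → None: every step matches, nothing changes
```

This mirrors the r0/r1 pair in the example above, which share the 0 → 123 → 392 prefix before splitting into separate children.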
- property observation_keys: list[tensordict._nestedkey.NestedKey]¶
The observation keys.
Returns the keys used to retrieve observations from the environment's output. The default observation key is "observation".
- Returns:
A list of strings or tuples representing the observation keys.
- property reward_keys: list[tensordict._nestedkey.NestedKey]¶
The reward keys.
Returns the keys used to retrieve rewards from the environment's output. The default reward key is "reward".
- Returns:
A list of strings or tuples representing the reward keys.
- to_string(td_root, node_format_fn=<function MCTSForest.<lambda>>)[source]¶
Generate a string representation of a tree in the forest.
This function can pull out information from each of the nodes in a tree, so it can be useful for debugging. Nodes are listed line-by-line. Each line contains the path to the node, followed by a string representation of that node generated with node_format_fn. Each line is indented according to the number of steps in the path required to reach its node.
- Parameters:
td_root (TensorDict) – the root of the tree.
node_format_fn (Callable, optional) – a user-defined function to generate a string for each node of the tree. The signature must be (Tree) -> Any, and the output must be convertible to a string. If this argument is not given, the generated string is the node's Tree.node_data attribute converted to a dict.
Examples
>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> forest = MCTSForest()
>>> td_root = TensorDict({"observation": 0,})
>>> rollouts_data = [
...     # [(action, obs), ...]
...     [(3, 123), (1, 456)],
...     [(2, 359), (2, 3094)],
...     [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
...     [(1, 75)],
...     [(3, 123), (0, 948)],
...     [(2, 359), (2, 3094), (10, 68)],
...     [(2, 359), (2, 3094), (11, 9045)],
... ]
>>> for rollout_data in rollouts_data:
...     td = td_root.clone().unsqueeze(0)
...     for action, obs in rollout_data:
...         td = td.update(TensorDict({
...             "action": [action],
...             "next": TensorDict({"observation": [obs]}, [1]),
...         }, [1]))
...         forest.extend(td)
...         td = td["next"].clone()
...
>>> print(forest.to_string(td_root))
(0,) {'observation': tensor(123)}
(0, 0) {'observation': tensor(456)}
(0, 1) {'observation': tensor(847)}
(0, 2) {'observation': tensor(948)}
(1,) {'observation': tensor(3094)}
(1, 0) {'observation': tensor(68)}
(1, 1) {'observation': tensor(9045)}
(2,) {'observation': tensor(75)}
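The path-per-line, indent-by-depth rendering that to_string describes can be mimicked in plain Python. The sketch below is hypothetical (the render function and the list-of-pairs tree encoding are illustrative, not torchrl code); it prints each node's path tuple followed by its data, indented by path length, analogous to the default node_format_fn showing node_data as a dict:

```python
# Hypothetical sketch of a to_string-style renderer: each line shows the
# node's path, then its data, indented by the depth of the path.
# Not the torchrl implementation.

def render(children: list, path: tuple = ()) -> list:
    """children is a list of (node_data, sub_children) pairs."""
    lines = []
    for i, (data, subs) in enumerate(children):
        p = path + (i,)
        # One indent level per step needed to reach the node.
        lines.append("  " * (len(p) - 1) + f"{p} {data}")
        lines.extend(render(subs, p))
    return lines


tree = [
    ({"observation": 123}, [({"observation": 456}, [])]),
    ({"observation": 3094}, []),
]
print("\n".join(render(tree)))
# → (0,) {'observation': 123}
# →   (0, 0) {'observation': 456}
# → (1,) {'observation': 3094}
```

A custom node_format_fn plays the role of the f"{data}" part here: it maps each Tree node to whatever string should follow the path on the line.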