快捷方式

DecisionTransformerInferenceWrapper

class torchrl.modules.tensordict_module.DecisionTransformerInferenceWrapper(*args, **kwargs)[源代码]

Decision Transformer 的推理动作包装器。

一个专门为 Decision Transformer 设计的包装器,它将输入 tensordict 序列掩码到推理上下文中。输出将是一个 TensorDict,其键与输入相同,但只包含预测动作序列的最后一个动作和最后一个 return to go。

此模块创建一个修改后的 tensordict 副本,即它 **不** 就地修改 tensordict。

注意

如果 action、observation 或 reward-to-go 键不是标准的,则应使用方法 set_tensor_keys(),例如:

>>> dt_inference_wrapper.set_tensor_keys(action="foo", observation="bar", return_to_go="baz")

in_keys 是 observation、action 和 return-to-go 键。out_keys 与 in_keys 匹配,并添加了策略中的任何其他 out_key(例如,分布的参数或隐藏值)。

参数:

policy (TensorDictModule) – 接收观察并产生动作值的策略模块

关键字参数:
  • inference_context (int) – 上下文中不会被掩码的先前动作数量。例如,对于形状为 [batch_size, context, obs_dim] 的观察输入,其中 context=20 且 inference_context=5,上下文的前 15 个条目将被掩码。默认为 5。

  • spec (Optional[TensorSpec]) – 输入 TensorDict 的规范。如果为 None,将从策略模块中推断出来。

  • device (torch.device, optional) – 如果提供,则为缓冲区/规范放置的设备。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import (
...      ProbabilisticActor,
...      TanhDelta,
...      DTActor,
...      DecisionTransformerInferenceWrapper,
...  )
>>> dtactor = DTActor(state_dim=4, action_dim=2,
...             transformer_config=DTActor.default_config()
... )
>>> actor_module = TensorDictModule(
...         dtactor,
...         in_keys=["observation", "action", "return_to_go"],
...         out_keys=["param"])
>>> dist_class = TanhDelta
>>> dist_kwargs = {
...     "low": -1.0,
...     "high": 1.0,
... }
>>> actor = ProbabilisticActor(
...     in_keys=["param"],
...     out_keys=["action"],
...     module=actor_module,
...     distribution_class=dist_class,
...     distribution_kwargs=dist_kwargs)
>>> inference_actor = DecisionTransformerInferenceWrapper(actor)
>>> sequence_length = 20
>>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4),
...                 "action": torch.randn(1, sequence_length, 2),
...                 "return_to_go": torch.randn(1, sequence_length, 1)}, [1,])
>>> result = inference_actor(td)
>>> print(result)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        param: Tensor(shape=torch.Size([1, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase = None) TensorDictBase[源代码]

定义每次调用时执行的计算。

所有子类都应重写此方法。

注意

虽然 forward pass 的实现需要在该函数中定义,但之后应该调用 Module 实例而不是直接调用它,因为前者负责运行注册的 hook,而后者则会静默地忽略它们。

mask_context(tensordict: TensorDictBase) TensorDictBase[源代码]

掩码输入序列的上下文。

set_tensor_keys(**kwargs)[源代码]

设置模块的输入键。

关键字参数:
  • observation (NestedKey, optional) – 观察键。

  • action (NestedKey, optional) – 动作键(输入到网络)。

  • return_to_go (NestedKey, optional) – return_to_go 键。

  • out_action (NestedKey, optional) – 动作键(输出自网络)。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源