快捷方式

QValueHook

class torchrl.modules.QValueHook(action_space: str, var_nums: int | None = None, action_value_key: NestedKey | None = None, action_mask_key: NestedKey | None = None, out_keys: Sequence[NestedKey] | None = None)[源]

Q值策略的 Q 值钩子。

给定一个常规 nn.Module 的输出,该输出表示不同离散动作的可用值,QValueHook 将把这些值转换为它们的 argmax 分量(即,结果贪婪动作)。

参数:
  • action_space (str) – 动作空间。必须是以下之一:"one-hot""mult-one-hot""binary""categorical"

  • var_nums (int, optional) – 如果 action_space = "mult-one-hot",则此值表示每个动作分量的基数。

  • action_value_key (str or tuple of str, optional) – 当钩接到 TensorDictModule 上使用。表示动作值的输入键。默认为 "action_value"

  • action_mask_key (str or tuple of str, optional) – 表示动作掩码的输入键。默认为 "None"(相当于无掩码)。

  • out_keys (list of str or tuple of str, optional) – 当钩接到 TensorDictModule 上使用。表示动作、动作值和所选动作值的输出键。默认为 ["action", "action_value", "chosen_action_value"]

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> module = nn.Linear(4, 4)
>>> hook = QValueHook("one_hot")
>>> module.register_forward_hook(hook)
>>> action_spec = OneHot(4)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源