ActorValueOperator¶
- class torchrl.modules.tensordict_module.ActorValueOperator(*args, **kwargs)[source]¶
Actor-value operator.
This class wraps together an actor and a value model that share a common observation embedding network.
Note
For a similar class that returns an action and a quality value \(Q(s, a)\), see ActorCriticOperator. For a version without a common embedding, refer to ActorCriticWrapper.
To facilitate the workflow, this class comes with get_policy_operator() and get_value_operator() methods, which both return a standalone TDModule with the dedicated functionality.
- Parameters:
common_operator (TensorDictModule) – a common operator that reads observations and produces a hidden variable.
policy_operator (TensorDictModule) – a policy operator that reads the hidden variable and returns an action.
value_operator (TensorDictModule) – a value operator that reads the hidden variable and returns a value.
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import ProbabilisticActor, SafeModule
>>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamExtractor
>>> module_hidden = torch.nn.Linear(4, 4)
>>> td_module_hidden = SafeModule(
...     module=module_hidden,
...     in_keys=["observation"],
...     out_keys=["hidden"],
... )
>>> module_action = TensorDictModule(
...     nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
...     in_keys=["hidden"],
...     out_keys=["loc", "scale"],
... )
>>> td_module_action = ProbabilisticActor(
...     module=module_action,
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=TanhNormal,
...     return_log_prob=True,
... )
>>> module_value = torch.nn.Linear(4, 1)
>>> td_module_value = ValueOperator(
...     module=module_value,
...     in_keys=["hidden"],
... )
>>> td_module = ActorValueOperator(td_module_hidden, td_module_action, td_module_value)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> td_clone = td_module(td.clone())
>>> print(td_clone)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_policy_operator()(td.clone())
>>> print(td_clone)  # no value
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_value_operator()(td.clone())
>>> print(td_clone)  # no action
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
- get_policy_head() → SafeSequential [source]¶
Returns the policy head.
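As a rough sketch of how the head differs from get_policy_operator(), reusing the td_module variable and key names from the Examples block above (not part of the original docstring): the policy head consumes the intermediate "hidden" entry produced by the common operator rather than "observation".
>>> # Sketch only: assumes td_module from the Examples above is in scope.
>>> policy_head = td_module.get_policy_head()
>>> td_hidden = TensorDict({"hidden": torch.randn(3, 4)}, [3])
>>> td_out = policy_head(td_hidden)
>>> assert "action" in td_out.keys()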
- get_policy_operator() → SafeSequential [source]¶
Returns a standalone policy operator that maps an observation to an action.
- get_value_head() → SafeSequential [source]¶
Returns the value head.
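A similar sketch for the value head, again reusing td_module from the Examples above (an illustration, not part of the original docstring): unlike get_value_operator(), which starts from "observation", the head maps the "hidden" entry directly to "state_value".
>>> # Sketch only: assumes td_module from the Examples above is in scope.
>>> value_head = td_module.get_value_head()
>>> td_hidden = TensorDict({"hidden": torch.randn(3, 4)}, [3])
>>> print(value_head(td_hidden)["state_value"].shape)
torch.Size([3, 1])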
- get_value_operator() → SafeSequential [source]¶
Returns a standalone value network operator that maps an observation to a value estimate.