OnlineDTActor¶
- class torchrl.modules.OnlineDTActor(state_dim: int, action_dim: int, transformer_config: dict | DecisionTransformer.DTConfig = None, device: DEVICE_TYPING | None = None)[源代码]¶
在线决策 Transformer Actor 类。
在线决策 Transformer 的 Actor 类,用于从高斯分布中采样动作,如 “Online Decision Transformer” 所述。
返回用于从高斯分布中采样动作的均值和标准差。
- 参数:
state_dim (int) – 状态维度。
action_dim (int) – 动作维度。
transformer_config (Dict 或
DecisionTransformer.DTConfig
) – GPT2 transformer 的配置。默认为default_config()
。device (torch.device, optional) – 要使用的设备。默认为 None。
示例
>>> model = OnlineDTActor(state_dim=4, action_dim=2, ... transformer_config=OnlineDTActor.default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) >>> mu, std = model(observation, action, return_to_go) >>> mu.shape torch.Size([32, 10, 2]) >>> std.shape torch.Size([32, 10, 2])
- classmethod default_config()[源代码]¶
OnlineDTActor
的默认配置。
- forward(observation: Tensor, action: Tensor, return_to_go: Tensor) tuple[torch.Tensor, torch.Tensor] [源代码]¶
定义每次调用时执行的计算。
所有子类都应重写此方法。
注意
虽然 forward pass 的实现需要在该函数内定义,但之后应该调用
Module
实例而不是此函数,因为前者会处理已注册的钩子,而后者则会默默忽略它们。