快捷方式

MultiAgentMLP

class torchrl.modules.MultiAgentMLP(n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, *, centralized: bool | None = None, share_params: bool | None = None, device: DEVICE_TYPING | None = None, depth: int | None = None, num_cells: Sequence | int | None = None, activation_class: type[nn.Module] | None = <class 'torch.nn.modules.activation.Tanh'>, use_td_params: bool = True, **kwargs)[源代码]

多智能体 MLP。

这是一个可在多智能体场景中使用的 MLP。例如,可用作策略或价值函数。请参阅 examples/multiagent 了解示例。

它期望输入形状为(*B, n_agents, n_agent_inputs),输出形状为(*B, n_agents, n_agent_outputs)

如果 share_params 为 True,则所有智能体将使用相同的 MLP 进行前向传播(同质策略)。否则,每个智能体将使用不同的 MLP 来处理其输入(异质策略)。

如果 centralized 为 True,每个智能体将使用所有智能体的输入来计算其输出(n_agent_inputs * n_agents 将是一个智能体的输入数量)。否则,每个智能体将仅使用其自身数据作为输入。

参数:
  • n_agent_inputs (intNone) – 每个智能体的输入数量。如果为 None,则输入数量在第一次调用时延迟实例化。

  • n_agent_outputs (int) – 每个智能体的输出数量。

  • n_agents (int) – 代理数量。

关键字参数:
  • centralized (bool) – 如果 centralized 为 True,每个智能体将使用所有智能体的输入来计算其输出(n_agent_inputs * n_agents 将是一个智能体的输入数量)。否则,每个智能体将仅使用其数据作为输入。

  • share_params (bool) – 如果 share_params 为 True,则所有智能体将使用相同的 MLP 进行前向传播(同质策略)。否则,每个智能体将使用不同的 MLP 来处理其输入(异质策略)。

  • device (strtoech.device, 可选) – 创建模块的设备。

  • depth (int, 可选) – 网络的深度。深度为 0 将生成一个具有所需输入和输出大小的单个线性层网络。长度为 1 将创建 2 个线性层,以此类推。如果未指示深度,则深度信息应包含在 num_cells 参数中(见下文)。如果 num_cells 是可迭代的且指示了深度,两者应匹配:len(num_cells) 必须等于 depth。默认值:3。

  • num_cells (intSequence[int], 可选) – 输入和输出之间的每层的单元数。如果提供整数,则每层具有相同的单元数。如果提供可迭代对象,则线性层的 out_features 将与 num_cells 的内容匹配。默认值:32。

  • activation_class (Type[nn.Module]) – 要使用的激活类。默认值:nn.Tanh。

  • use_td_params (bool, 可选) – 如果为 True,则参数可以在 self.params 中找到,它是一个 TensorDictParams 对象(它同时继承自 TensorDictnn.Module)。如果为 False,参数包含在 self._empty_net 中。总而言之,这两种方法应该大致相同,但不可互换:例如,使用 use_td_params=True 创建的 state_dict 不能在 use_td_params=False 时使用。

  • **kwargs – 可以传递给 torchrl.modules.models.MLP 以自定义 MLP。

注意

要使用 torch.nn.init 模块初始化 MARL 模块参数,请参阅 get_stateful_net()from_stateful_net() 方法。

示例

>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs=3
>>> n_agent_outputs=2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
Now let's instantiate a centralized network shared by all agents (e.g. a centalised value function)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=True,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=18, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
We can see that the input to the first layer is n_agents * n_agent_inputs,
this is because in the case the net acts as a centralized mlp (like a single huge agent)
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
Outputs will be identical for all agents.
Now we can do both examples just shown but with an independent set of parameters for each agent
Let's show the centralized=False case.
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=False,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0-5): 6 x MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent!
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源