快捷方式

ConvNet

class torchrl.modules.ConvNet(in_features: int | None = None, depth: int | None = None, num_cells: Sequence[int] | int = None, kernel_sizes: Sequence[int] | int = 3, strides: Sequence[int] | int = 1, paddings: Sequence[int] | int = 0, activation_class: type[nn.Module] | Callable = <class 'torch.nn.modules.activation.ELU'>, activation_kwargs: dict | list[dict] | None = None, norm_class: type[nn.Module] | Callable | None = None, norm_kwargs: dict | list[dict] | None = None, bias_last_layer: bool = True, aggregator_class: type[nn.Module] | Callable | None = <class 'torchrl.modules.models.utils.SquashDims'>, aggregator_kwargs: dict | None = None, squeeze_output: bool = False, device: DEVICE_TYPING | None = None)[源]

一个卷积神经网络。

参数:
  • in_features (int, optional) – 输入特征的数量。如果为 None,则第一个层将使用 LazyConv2d 模块;

  • depth (int, optional) – 网络深度。深度为 1 将产生一个具有所需输入大小的单个线性层网络,其输出大小等于 num_cells 参数的最后一个元素。如果未指定深度,则深度信息应包含在 num_cells 参数中(见下文)。如果 num_cells 是可迭代的,并且指定了 depth,则两者应匹配:len(num_cells) 必须等于 depth

  • num_cells (intSequence of int, optional) – 输入和输出之间每层的单元数。如果提供整数,则每层将具有相同的单元数。如果提供可迭代对象,则线性层的 out_features 将与 num_cells 的内容匹配。默认为 [32, 32, 32]

  • kernel_sizes (int, sequence of int, optional) – 卷积网络的核大小。如果为可迭代对象,则长度必须与由 num_cells 或 depth 参数定义的深度匹配。默认为 3

  • strides (intsequence of int, optional) – 卷积网络的步幅。如果为可迭代对象,则长度必须与由 num_cells 或 depth 参数定义的深度匹配。默认为 1

  • activation_class (Type[nn.Module] 或 callable, optional) – 要使用的激活类或构造函数。默认为 Tanh

  • activation_kwargs (dictlist of dicts, optional) – 要与激活类一起使用的 kwargs。也可以传递一个长度为 depth 的 kwargs 列表,每个层一个元素。

  • norm_class (Typecallable, optional) – 归一化类或构造函数(如果存在)。

  • norm_kwargs (dictlist of dicts, optional) – 要与归一化层一起使用的 kwargs。也可以传递一个长度为 depth 的 kwargs 列表,每个层一个元素。

  • bias_last_layer (bool) – 如果为 True,最后一个线性层将具有偏置参数。默认为 True

  • aggregator_class (Type[nn.Module] 或 callable) – 在链的末尾使用的聚合器类或构造函数。默认为 torchrl.modules.utils.models.SquashDims

  • aggregator_kwargs (dict, optional) – aggregator_class 的 kwargs。

  • squeeze_output (bool) – 输出是否应被压缩其单例维度。默认为 False

  • device (torch.device, optional) – 创建模块的设备。

示例

>>> # All of the following examples provide valid, working MLPs
>>> cnet = ConvNet(in_features=3, depth=1, num_cells=[32,]) # MLP consisting of a single 3 x 6 linear layer
>>> print(cnet)
ConvNet(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): SquashDims()
)
>>> cnet = ConvNet(in_features=3, depth=4, num_cells=32)
>>> print(cnet)
ConvNet(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ELU(alpha=1.0)
  (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (5): ELU(alpha=1.0)
  (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (7): ELU(alpha=1.0)
  (8): SquashDims()
)
>>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35])  # defines the depth by the num_cells arg
>>> print(cnet)
ConvNet(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): Conv2d(32, 33, kernel_size=(3, 3), stride=(1, 1))
  (3): ELU(alpha=1.0)
  (4): Conv2d(33, 34, kernel_size=(3, 3), stride=(1, 1))
  (5): ELU(alpha=1.0)
  (6): Conv2d(34, 35, kernel_size=(3, 3), stride=(1, 1))
  (7): ELU(alpha=1.0)
  (8): SquashDims()
)
>>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35], kernel_sizes=[3, 4, 5, (2, 3)])  # defines kernels, possibly rectangular
>>> print(cnet)
ConvNet(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): Conv2d(32, 33, kernel_size=(4, 4), stride=(1, 1))
  (3): ELU(alpha=1.0)
  (4): Conv2d(33, 34, kernel_size=(5, 5), stride=(1, 1))
  (5): ELU(alpha=1.0)
  (6): Conv2d(34, 35, kernel_size=(2, 3), stride=(1, 1))
  (7): ELU(alpha=1.0)
  (8): SquashDims()
)
classmethod default_atari_dqn(num_actions: int)[源]

返回标志性 DQN 论文中提出的默认 DQN。

参数:

num_actions (int) – atari 游戏的动作空间。

forward(inputs: Tensor) Tensor[源]

定义每次调用时执行的计算。

所有子类都应重写此方法。

注意

虽然前向传播的实现需要在该函数内部定义,但之后应该调用 Module 实例而不是此函数,因为前者会处理已注册的钩子,而后者会默默地忽略它们。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源