ModuleDict#
- class torch.nn.ModuleDict(modules=None)[source]#
以字典形式保存子模块。
ModuleDict
可以像普通 Python 字典一样索引,但其中包含的模块会被正确注册,并且对所有Module
方法可见。ModuleDict
是一个**有序**字典,它遵循插入顺序,并且
在
update()
中,会保持合并的OrderedDict
、dict
(从 Python 3.6 开始) 或另一个ModuleDict
(作为update()
的参数) 的顺序。
请注意,使用其他无序映射类型 (例如,Python 3.6 之前的普通
dict
) 调用update()
不会保留合并映射的顺序。- 参数
modules (iterable, optional) – 一个 (string: module) 的映射 (字典) 或一个 (string, module) 类型键值对的可迭代对象
示例
class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.choices = nn.ModuleDict( {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)} ) self.activations = nn.ModuleDict( [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]] ) def forward(self, x, choice, act): x = self.choices[choice](x) x = self.activations[act](x) return x
- update(modules)[source]#
使用来自映射的键值对更新
ModuleDict
,并覆盖现有键。注意
如果
modules
是一个OrderedDict
、一个ModuleDict
,或一个键值对的可迭代对象,则其中新元素的顺序将被保留。