ParameterDict#
- class torch.nn.modules.container.ParameterDict(parameters=None)[source]#
以字典形式保存参数。
ParameterDict 的索引方式与普通 Python 字典相同,但其中包含的 Parameters 会被正确注册,并且对所有 Module 方法可见。其他对象则按照普通 Python 字典的处理方式进行处理。
ParameterDict是一个有序字典。使用其他无序映射类型(例如 Python 原生的dict)调用update()时,不会保留合并映射的顺序。另一方面,OrderedDict或另一个ParameterDict将保留其顺序。请注意,构造函数、为字典元素赋值以及
update()方法会将任何Tensor转换为Parameter。- 参数:
values (iterable, optional) – 一个 (字符串: Any) 的映射(字典),或类型为 (字符串, Any) 的键值对可迭代对象
示例
class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.params = nn.ParameterDict( { "left": nn.Parameter(torch.randn(5, 10)), "right": nn.Parameter(torch.randn(5, 10)), } ) def forward(self, x, choice): x = self.params[choice].mm(x) return x
- copy()[source]#
返回该
ParameterDict实例的副本。- 返回类型:
- setdefault(key, default=None)[source]#
为 ParameterDict 中的键设置默认值。
如果键存在于 ParameterDict 中,则返回其值。如果不存在,则插入 key 并将参数设置为 default,然后返回 default。default 默认为 None。
- update(parameters)[source]#
使用
parameters中的键值对更新ParameterDict,覆盖现有键。注意
如果
parameters是OrderedDict、ParameterDict或键值对的可迭代对象,则其中新元素的顺序会被保留。- 参数:
parameters (iterable) – 从字符串到
Parameter的映射(字典),或类型为 (字符串,Parameter) 的键值对可迭代对象