ParameterDict#
- class torch.nn.modules.container.ParameterDict(parameters=None)[source]#
以字典形式保存参数。
ParameterDict 可以像普通 Python 字典一样进行索引,但其中包含的 Parameters 会被正确注册,并且对所有 Module 方法可见。其他对象将被视为普通 Python 字典进行处理。
ParameterDict
是一个有序字典。使用其他无序映射类型(例如 Python 的普通dict
)进行update()
操作不会保留合并映射的顺序。另一方面,使用OrderedDict
或另一个ParameterDict
会保留它们的顺序。请注意,构造函数、字典元素的赋值以及
update()
方法会将任何Tensor
转换为Parameter
。- 参数
values (iterable, optional) – 一个 (string : Any) 的映射(字典)或一个 (string, Any) 类型的键值对的可迭代对象
示例
class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.params = nn.ParameterDict( { "left": nn.Parameter(torch.randn(5, 10)), "right": nn.Parameter(torch.randn(5, 10)), } ) def forward(self, x, choice): x = self.params[choice].mm(x) return x
- copy()[source]#
返回此
ParameterDict
实例的副本。- 返回类型
- setdefault(key, default=None)[source]#
设置 ParameterDict 中键的默认值。
如果键在 ParameterDict 中,则返回其值。否则,插入 key 和参数 default 并返回 default。default 默认为 None。
- update(parameters)[source]#
使用
parameters
中的键值对更新ParameterDict
,并覆盖现有键。注意
如果
parameters
是OrderedDict
、ParameterDict
或键值对的可迭代对象,则其中新元素的顺序将被保留。- 参数
parameters (iterable) – 一个从 string 到
Parameter
的映射(字典),或一个 (string,Parameter
) 类型的键值对的可迭代对象