评价此页

ParameterDict#

class torch.nn.modules.container.ParameterDict(parameters=None)[source]#

以字典形式保存参数。

ParameterDict 的索引方式与普通 Python 字典相同,但其中包含的 Parameters 会被正确注册,并且对所有 Module 方法可见。其他对象则按照普通 Python 字典的处理方式进行处理。

ParameterDict 是一个有序字典。使用其他无序映射类型(例如 Python 原生的 dict)调用 update() 时,不会保留合并映射的顺序。另一方面,OrderedDict 或另一个 ParameterDict 将保留其顺序。

请注意,构造函数、为字典元素赋值以及 update() 方法会将任何 Tensor 转换为 Parameter

参数:

values (iterable, optional) – 一个 (字符串: Any) 的映射(字典),或类型为 (字符串, Any) 的键值对可迭代对象

示例

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.params = nn.ParameterDict(
            {
                "left": nn.Parameter(torch.randn(5, 10)),
                "right": nn.Parameter(torch.randn(5, 10)),
            }
        )

    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x
clear()[source]#

移除 ParameterDict 中的所有项。

copy()[source]#

返回该 ParameterDict 实例的副本。

返回类型:

参数字典

fromkeys(keys, default=None)[source]#

返回一个包含所提供键的新 ParameterDict。

参数:
  • keys (iterable, string) – 用于创建新 ParameterDict 的键

  • default (Parameter, optional) – 为所有键设置的值

返回类型:

参数字典

get(key, default=None)[source]#

如果存在,返回与键关联的参数。否则,如果提供了 default,则返回该默认值,否则返回 None。

参数:
  • key (str) – 要从 ParameterDict 中获取的键

  • default (Parameter, optional) – 如果键不存在时返回的值

返回类型:

任何

items()[source]#

返回 ParameterDict 键值对的可迭代对象。

返回类型:

Iterable[tuple[str, Any]]

keys()[source]#

返回 ParameterDict 键的可迭代对象。

返回类型:

KeysView[str]

pop(key)[source]#

从 ParameterDict 中移除键并返回其参数。

参数:

key (str) – 要从 ParameterDict 中弹出的键

返回类型:

任何

popitem()[source]#

移除并返回 ParameterDict 中最后插入的 (key, parameter) 对。

返回类型:

tuple[str, Any]

setdefault(key, default=None)[source]#

为 ParameterDict 中的键设置默认值。

如果键存在于 ParameterDict 中,则返回其值。如果不存在,则插入 key 并将参数设置为 default,然后返回 defaultdefault 默认为 None

参数:
  • key (str) – 要设置默认值的键

  • default (Any) – 设置给该键的参数

返回类型:

任何

update(parameters)[source]#

使用 parameters 中的键值对更新 ParameterDict,覆盖现有键。

注意

如果 parametersOrderedDictParameterDict 或键值对的可迭代对象,则其中新元素的顺序会被保留。

参数:

parameters (iterable) – 从字符串到 Parameter 的映射(字典),或类型为 (字符串, Parameter) 的键值对可迭代对象

values()[source]#

返回 ParameterDict 值值的可迭代对象。

返回类型:

Iterable[Any]