评价此页

torch.optim.Optimizer.state_dict#

Optimizer.state_dict()[source]#

将优化器的状态作为 dict 返回。

它包含两个条目

  • state:一个包含当前优化状态的 Dict。其内容

    与优化器类之间存在差异,但有些通用特性是相同的。例如,状态是为每个参数保存的,但参数本身不被保存。 state 是一个字典,它将参数 id 映射到一个字典,该字典包含与每个参数相对应的状态。

  • param_groups:一个包含所有参数组的 List,其中每个

    参数组是一个字典。每个参数组包含优化器特有的元数据,例如学习率和权重衰减,以及属于该组的参数 ID 列表。如果参数组使用 named_parameters() 初始化,则名称内容也会保存在 state_dict 中。

注意:参数 ID 可能看起来像索引,但它们只是将状态与 param_group 关联的 ID。从 state_dict 加载时,优化器会将 param_group params(整数 ID)与优化器 param_groups(实际的 nn.Parameter)进行 zip 操作,以匹配状态,而无需额外验证。

返回的状态字典可能看起来像

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
            'param_names' ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
返回类型

dict[str, Any]