torch.optim.Optimizer.load_state_dict#
- Optimizer.load_state_dict(state_dict)[源代码]#
加载优化器状态。
- 参数
state_dict (dict) – 优化器状态。应该是一个从调用
state_dict()
返回的对象。
警告
请确保在初始化
torch.optim.lr_scheduler.LRScheduler
后调用此方法,因为在此之前调用会覆盖加载的学习率。注意
参数的名称(如果存在于
state_dict()
中每个参数组的“param_names”键下)不会影响加载过程。要使用参数名称进行自定义(例如,当加载的状态字典中的参数与优化器中初始化的参数不同时),应实现自定义的register_load_state_dict_pre_hook
来相应地调整加载的字典。如果param_names
存在于加载的状态字典param_groups
中,它们将被保存并覆盖优化器状态中当前存在的名称。如果它们不存在于加载的状态字典中,优化器的param_names
将保持不变。示例
>>> model = torch.nn.Linear(10, 10) >>> optim = torch.optim.SGD(model.parameters(), lr=3e-4) >>> scheduler1 = torch.optim.lr_scheduler.LinearLR( ... optim, ... start_factor=0.1, ... end_factor=1, ... total_iters=20, ... ) >>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( ... optim, ... T_max=80, ... eta_min=3e-5, ... ) >>> lr = torch.optim.lr_scheduler.SequentialLR( ... optim, ... schedulers=[scheduler1, scheduler2], ... milestones=[20], ... ) >>> lr.load_state_dict(torch.load("./save_seq.pt")) >>> # now load the optimizer checkpoint after loading the LRScheduler >>> optim.load_state_dict(torch.load("./save_optim.pt"))