torch.optim#
创建于: Jun 13, 2025 | 最后更新于: Aug 24, 2025
torch.optim 是一个实现了各种优化算法的包。
最常用的方法已经得到支持,并且接口足够通用,以便将来可以轻松地集成更复杂的方法。
如何使用优化器#
要使用 torch.optim,您必须构造一个优化器对象,该对象将保存当前状态并根据计算出的梯度更新参数。
构造优化器#
要构造一个 Optimizer,您必须提供一个包含要优化的参数(所有参数都应该是 Parameter)或命名参数(元组 (str, Parameter))的可迭代对象。然后,您可以指定特定于优化器的选项,例如学习率、权重衰减等。
示例
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
命名参数示例
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)
每个参数的选项#
Optimizer 还支持为每个参数指定选项。为此,请将 Variable 的可迭代对象替换为 dict 的可迭代对象。每个字典将定义一个单独的参数组,并应包含一个 params 键,其中包含属于该组的参数列表。其他键应匹配优化器接受的关键字参数,并将用作该组的优化选项。
例如,这在一个人想要为每个层指定学习率时非常有用
optim.SGD([
{'params': model.base.parameters(), 'lr': 1e-2},
{'params': model.classifier.parameters()}
], lr=1e-3, momentum=0.9)
optim.SGD([
{'params': model.base.named_parameters(), 'lr': 1e-2},
{'params': model.classifier.named_parameters()}
], lr=1e-3, momentum=0.9)
这意味着 model.base 的参数将使用学习率为 1e-2,而 model.classifier 的参数将保持默认学习率 1e-3。最后,所有参数将使用 0.9 的动量。
注意
您仍然可以将选项作为关键字参数传递。它们将用作默认值,用于那些未覆盖它们的组。当您只想更改一个选项,同时保持所有其他选项在参数组之间一致时,这非常有用。
另外,请考虑以下与参数区分惩罚相关的示例。请记住,parameters() 返回一个包含所有可学习参数的可迭代对象,包括偏置和其他可能需要区分惩罚的参数。为了解决这个问题,可以为每个参数组指定单独的惩罚权重
bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
others = [p for name, p in self.named_parameters() if 'bias' not in name]
optim.SGD([
{'params': others},
{'params': bias_params, 'weight_decay': 0}
], weight_decay=1e-2, lr=1e-2)
通过这种方式,偏置项与非偏置项分开,并且偏置项的 weight_decay 设置为 0,以避免对该组进行任何惩罚。
执行优化步骤#
所有优化器都实现了一个 step() 方法,该方法更新参数。它可以两种方式使用
optimizer.step()#
这是大多数优化器支持的简化版本。在使用例如 backward() 计算出梯度后,可以调用该函数一次。
示例
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.step(closure)#
一些优化算法,例如共轭梯度和 LBFGS,需要多次重新评估函数,因此您必须传入一个闭包,允许它们重新计算模型。闭包应清除梯度,计算损失并返回它。
示例
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
return loss
optimizer.step(closure)
基类#
- class torch.optim.Optimizer(params, defaults)[source]#
所有优化器的基类。
警告
参数必须指定为具有确定性顺序的可迭代对象,并且在运行之间保持一致。不满足这些属性的对象示例是集合和字典值的迭代器。
- 参数:
params (iterable) – 一个
torch.Tensor或dict的可迭代对象。指定哪些 Tensor 应该被优化。defaults (dict[str, Any]) – (dict): 一个包含优化选项默认值的字典(当参数组未指定时使用)。
将一个参数组添加到 |
|
加载优化器状态。 |
|
注册一个 load_state_dict 前置钩子,该钩子将在调用 |
|
注册一个 load_state_dict 后置钩子,该钩子将在调用 |
|
将优化器的状态作为 |
|
注册一个 state dict 前置钩子,该钩子将在调用 |
|
注册一个 state dict 后置钩子,该钩子将在调用 |
|
执行一次优化步骤来更新参数。 |
|
注册一个优化器步骤预钩子,它将在优化器步骤之前被调用。 |
|
注册一个优化器步骤后钩子,它将在优化器步骤之后被调用。 |
|
重置所有优化 |
算法#
实现了 Adadelta 算法。 |
|
实现了 Adafactor 算法。 |
|
实现了 Adagrad 算法。 |
|
实现了 Adam 算法。 |
|
实现了 AdamW 算法,其中权重衰减不累积到动量或方差中。 |
|
SparseAdam 实现了一个 Adam 算法的掩码版本,适用于稀疏梯度。 |
|
实现了 Adamax 算法(基于无穷范数的 Adam 变体)。 |
|
实现了平均随机梯度下降。 |
|
实现了 L-BFGS 算法。 |
|
实现了 Muon 算法。 |
|
实现了 NAdam 算法。 |
|
实现了 RAdam 算法。 |
|
实现了 RMSprop 算法。 |
|
实现了弹性反向传播算法。 |
|
实现了随机梯度下降(可选带动量)。 |
我们的许多算法都有针对性能、可读性和/或通用性进行优化的各种实现,因此我们会尝试在用户未指定任何特定实现的情况下,默认使用当前设备上通用的最快实现。
我们有 3 种主要的实现类别:for-loop(循环)、foreach(多张量)和 fused(融合)。最直接的实现是在参数上进行循环,并进行大量计算。循环通常比我们的 foreach 实现慢,foreach 将参数组合成一个多张量,并一次性进行大量计算,从而节省了许多顺序内核调用。我们的一些优化器甚至有更快的融合实现,它们将大量计算融合到一个内核中。我们可以将 foreach 实现视为横向融合,将融合实现视为在此基础上进行纵向融合。
通常,这三种实现的性能排序是 fused > foreach > for-loop。因此,在适用时,我们默认使用 foreach 而非 for-loop。适用意味着 foreach 实现可用,用户未指定任何特定于实现的 kwargs(例如,fused、foreach、differentiable),并且所有张量都是原生的。请注意,虽然融合应该比 foreach 更快,但这些实现较新,并且在全面启用之前,我们希望让它们有更多的时间进行完善。我们在下表第二个表格中总结了每种实现的稳定性状态,欢迎您尝试!
下表显示了每种算法的可用实现和默认实现
算法 |
默认 |
有 foreach? |
有 fused? |
|---|---|---|---|
foreach |
是 |
否 |
|
for-loop |
否 |
否 |
|
foreach |
是 |
是(仅限 CPU) |
|
foreach |
是 |
是 |
|
foreach |
是 |
是 |
|
for-loop |
否 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
否 |
|
for-loop |
否 |
否 |
|
for-loop |
否 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
是 |
下表显示了融合实现的稳定性状态
算法 |
CPU |
CUDA |
MPS |
|---|---|---|---|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
beta |
不支持 |
不支持 |
|
beta |
稳定 |
beta |
|
beta |
稳定 |
beta |
|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
不支持 |
不支持 |
不支持 |
|
beta |
beta |
beta |
如何调整学习率#
torch.optim.lr_scheduler.LRScheduler 提供几种根据 epoch 数量调整学习率的方法。torch.optim.lr_scheduler.ReduceLROnPlateau 允许根据某些验证测量动态降低学习率。
学习率调度应在优化器更新后应用;例如,您应该这样编写代码
示例
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)
for epoch in range(20):
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
scheduler.step()
大多数学习率调度器可以连续调用(也称为链式调度器)。结果是每个调度器按顺序应用于前一个调度器获得学习率。
示例
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
for epoch in range(20):
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
scheduler1.step()
scheduler2.step()
在文档的许多地方,我们将使用以下模板来引用调度算法。
>>> scheduler = ...
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
警告
在 PyTorch 1.1.0 之前,学习率调度器预计在优化器更新之前调用;1.1.0 以向后不兼容的方式更改了此行为。如果您在升级到 PyTorch 1.1.0 后使用学习率调度器(在调用 optimizer.step() 之前调用 scheduler.step()),这将跳过学习率计划的第一个值。如果您在升级到 PyTorch 1.1.0 后无法重现结果,请检查是否在错误的时间调用了 scheduler.step()。
所有学习率调度器的基类。 |
|
设置初始学习率。 |
|
将每个参数组的学习率乘以指定函数中给出的因子。 |
|
每隔 step_size 个 epoch,将每个参数组的学习率乘以 gamma 进行衰减。 |
|
当 epoch 数量达到 milestones 中的一个时,将每个参数组的学习率乘以 gamma 进行衰减。 |
|
将每个参数组的学习率乘以一个小的常数因子。 |
|
通过线性改变小的乘法因子,衰减每个参数组的学习率。 |
|
每个 epoch 将每个参数组的学习率乘以 gamma 进行衰减。 |
|
使用给定 total_iters 中的多项式函数衰减每个参数组的学习率。 |
|
使用余弦退火调度设置每个参数组的学习率。 |
|
链接一系列学习率调度器。 |
|
包含一系列在优化过程中应按顺序调用的调度器。 |
|
当指标停止改进时,降低学习率。 |
|
根据周期性学习率策略(CLR)为每个参数组设置学习率。 |
|
根据 1cycle 学习率策略为每个参数组设置学习率。 |
|
使用余弦退火调度设置每个参数组的学习率。 |
如何使用命名参数加载优化器状态字典#
函数 load_state_dict() 如果存在,将存储加载的 state dict 中的可选 param_names 内容。然而,加载优化器状态的过程不受影响,因为参数的顺序对于保持兼容性(在顺序不同的情况下)很重要。要利用加载的 state dict 中的加载的参数名称,需要根据所需的行为实现自定义的 register_load_state_dict_pre_hook。
例如,当模型架构发生变化,但权重和优化器状态需要保持不变时,这可能很有用。以下示例演示了如何实现此自定义。
示例
class OneLayerModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(3, 4)
def forward(self, x):
return self.fc(x)
model = OneLayerModel()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)
假设 model 实现了一个专家(MoE),我们想复制它并为两个专家恢复训练,这两个专家都以与 fc 层相同的方式初始化。对于以下 model2,我们创建了两个与 fc 相同的层,并通过将 model 的模型权重和优化器状态加载到 model2 的 fc1 和 fc2 中来恢复训练(并相应地调整它们)。
class TwoLayerModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 4)
self.fc2 = nn.Linear(3, 4)
def forward(self, x):
return (self.fc1(x) + self.fc2(x)) / 2
model2 = TwoLayerModel()
# adapt and load model weights..
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)
为了加载 optimizer2 的 state dict(其 state dict 来自之前的优化器),使得 fc1 和 fc2 都将使用 fc 的副本进行初始化(以从 fc 恢复训练),我们可以使用以下钩子
def adapt_state_dict_ids(optimizer, state_dict):
adapted_state_dict = deepcopy(optimizer.state_dict())
# Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
for k, v in state_dict['param_groups'][0].items():
if k not in ['params', 'param_names']:
adapted_state_dict['param_groups'][0][k] = v
lookup_dict = {
'fc1.weight': 'fc.weight',
'fc1.bias': 'fc.bias',
'fc2.weight': 'fc.weight',
'fc2.bias': 'fc.bias'
}
clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
for param_id, param_name in zip(
optimizer.state_dict()['param_groups'][0]['params'],
optimizer.state_dict()['param_groups'][0]['param_names']):
name_in_loaded = lookup_dict[param_name]
index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
# Copy the state of the corresponding parameter
if id_in_loaded in state_dict['state']:
adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])
return adapted_state_dict
optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict
这确保了在模型加载期间将使用适应后的 state_dict,其中包含 model2 的层所需的状态。请注意,此代码专门为此示例设计(例如,假设只有一个参数组),其他情况可能需要不同的适应。
以下示例展示了如何在模型结构发生变化时处理加载的 state dict 中缺失的参数。Model_bypass 添加了一个新的 bypass 层,该层在原始 Model1 中不存在。为了恢复训练,使用了自定义的 adapt_state_dict_missing_param 钩子来适应优化器的 state_dict,确保现有参数被正确映射,而缺失的参数(如 bypass 层)则保持不变(如本例中的初始化)。这种方法使得即使模型发生变化,也能平滑地加载和恢复优化器状态。新的 bypass 层将从头开始训练。
class Model1(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 5)
def forward(self, x):
return self.fc(x) + x
model = Model1()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)
class Model_bypass(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 5)
self.bypass = nn.Linear(5, 5, bias=False)
torch.nn.init.eye_(self.bypass.weight)
def forward(self, x):
return self.fc(x) + self.bypass(x)
model2 = Model_bypass()
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)
def adapt_state_dict_missing_param(optimizer, state_dict):
adapted_state_dict = deepcopy(optimizer.state_dict())
# Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
for k, v in state_dict['param_groups'][0].items():
if k not in ['params', 'param_names']:
adapted_state_dict['param_groups'][0][k] = v
lookup_dict = {
'fc.weight': 'fc.weight',
'fc.bias': 'fc.bias',
'bypass.weight': None,
}
clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
for param_id, param_name in zip(
optimizer.state_dict()['param_groups'][0]['params'],
optimizer.state_dict()['param_groups'][0]['param_names']):
name_in_loaded = lookup_dict[param_name]
if name_in_loaded in state_dict['param_groups'][0]['param_names']:
index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
# Copy the state of the corresponding parameter
if id_in_loaded in state_dict['state']:
adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])
return adapted_state_dict
optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict
作为第三个示例,而不是根据参数顺序加载状态(默认方法),可以使用此钩子根据参数名称加载。
def names_matching(optimizer, state_dict):
assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
adapted_state_dict = deepcopy(optimizer.state_dict())
for g_ind in range(len(state_dict['param_groups'])):
assert len(state_dict['param_groups'][g_ind]['params']) == len(
optimizer.state_dict()['param_groups'][g_ind]['params'])
for k, v in state_dict['param_groups'][g_ind].items():
if k not in ['params', 'param_names']:
adapted_state_dict['param_groups'][g_ind][k] = v
for param_id, param_name in zip(
optimizer.state_dict()['param_groups'][g_ind]['params'],
optimizer.state_dict()['param_groups'][g_ind]['param_names']):
index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
# Copy the state of the corresponding parameter
if id_in_loaded in state_dict['state']:
adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])
return adapted_state_dict
权重平均(SWA 和 EMA)#
torch.optim.swa_utils.AveragedModel 实现随机权重平均(SWA)和指数移动平均(EMA),torch.optim.swa_utils.SWALR 实现 SWA 学习率调度器,而 torch.optim.swa_utils.update_bn() 是一个用于在训练结束时更新 SWA/EMA 批量归一化统计的实用函数。
SWA 在 Averaging Weights Leads to Wider Optima and Better Generalization 中被提出。
EMA 是一种广泛已知的技术,通过减少所需的权重更新次数来缩短训练时间。它是 Polyak averaging 的一种变体,但使用指数权重而不是迭代之间的相等权重。
构造平均模型#
AveragedModel 类用于计算 SWA 或 EMA 模型的权重。
您可以通过运行以下命令创建一个 SWA 平均模型
>>> averaged_model = AveragedModel(model)
EMA 模型通过指定 multi_avg_fn 参数来构造,如下所示
>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))
Decay 是一个介于 0 和 1 之间的参数,它控制平均参数衰减的速度。如果未提供给 torch.optim.swa_utils.get_ema_multi_avg_fn(),则默认为 0.999。Decay 值应接近 1.0,因为较小的值可能导致优化收敛问题。
torch.optim.swa_utils.get_ema_multi_avg_fn() 返回一个函数,该函数将以下 EMA 方程应用于权重
其中 alpha 是 EMA 衰减率。
这里的模型 model 可以是任意的 torch.nn.Module 对象。averaged_model 将跟踪 model 参数的运行平均值。要更新这些平均值,您应该在 optimizer.step() 之后使用 update_parameters() 函数。
>>> averaged_model.update_parameters(model)
对于 SWA 和 EMA,这个调用通常在优化器 step() 之后立即完成。在 SWA 的情况下,这通常在训练开始的某些步数内被跳过。
自定义平均策略#
默认情况下,torch.optim.swa_utils.AveragedModel 计算您提供的参数的运行平均值,但您也可以使用 avg_fn 或 multi_avg_fn 参数来定义自定义平均函数。
avg_fn允许定义一个作用于每个参数元组(平均参数,模型参数)的函数,并应返回新的平均参数。multi_avg_fn允许同时定义更高效的操作,这些操作作用于参数列表元组(平均参数列表,模型参数列表),例如使用torch._foreach*函数。此函数必须就地更新平均参数。
在下面的示例中,ema_model 使用 avg_fn 参数计算指数移动平均。
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>> 0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
在下面的示例中,ema_model 使用更高效的 multi_avg_fn 参数计算指数移动平均。
>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))
SWA 学习率调度#
通常,在 SWA 中,学习率被设置为一个较高的常数值。SWALR 是一个学习率调度器,它将学习率退火到一个固定值,然后保持不变。例如,以下代码创建了一个调度器,该调度器在每个参数组内的 5 个 epoch 中将学习率从初始值线性退火到 0.05。
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>> anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
您也可以选择使用余弦退火到一个固定值而不是线性退火,方法是将 anneal_strategy="cos" 设置为。
处理批量归一化#
update_bn() 是一个实用函数,允许在训练结束时使用给定的数据加载器 loader 为 SWA 模型计算批量归一化统计信息。
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
update_bn() 将 swa_model 应用于数据加载器中的每个元素,并为模型中的每个批量归一化层计算激活统计信息。
警告
update_bn() 假设数据加载器 loader 中的每个批次是张量或张量列表,其中第一个元素是网络 swa_model 应该应用到的张量。如果您的数据加载器结构不同,您可以通过在数据集的每个元素上使用 swa_model 进行前向传递来更新 swa_model 的批量归一化统计信息。
整合:SWA#
在下面的示例中,swa_model 是累加权重平均值的 SWA 模型。我们总共训练模型 300 个 epoch,并在第 160 个 epoch 切换到 SWA 学习率调度器并开始收集参数的 SWA 平均值。
>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>> for input, target in loader:
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> if epoch > swa_start:
>>> swa_model.update_parameters(model)
>>> swa_scheduler.step()
>>> else:
>>> scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)
整合:EMA#
在下面的示例中,ema_model 是 EMA 模型,它以 0.999 的衰减率累加权重的指数衰减平均值。我们总共训练模型 300 个 epoch,并立即开始收集 EMA 平均值。
>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>> multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>> for input, target in loader:
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)
为随机权重平均 (SWA) 和指数移动平均 (EMA) 实现平均模型。 |
|
将每个参数组的学习率退火到一个固定值。 |
- torch.optim.swa_utils.update_bn(loader, model, device=None)[source]#
更新模型中的 BatchNorm running_mean、running_var 缓冲区。
它会遍历 loader 中的数据一次,以估计模型中 BatchNorm 层的激活统计信息。
- 参数:
loader (torch.utils.data.DataLoader) – 用于计算激活统计信息的数据集加载器。每个数据批次应该是张量,或者是一个列表/元组,其中第一个元素是包含数据的张量。
model (torch.nn.Module) – 我们希望更新 BatchNorm 统计信息的模型。
device (torch.device, optional) – 如果设置,数据将在传递到
model之前传输到device。
示例
>>> loader, model = ... >>> torch.optim.swa_utils.update_bn(loader, model)
注意
update_bn 工具假设
loader中的每个数据批次要么是张量,要么是张量列表或元组;在后一种情况下,假设model.forward()应该在列表或元组的第一个元素上调用,该元素对应于数据批次。