评价此页

torch.optim#

创建于: 2025年6月13日 | 最后更新于: 2025年8月24日

torch.optim 是一个实现了各种优化算法的包。

大多数常用方法已得到支持,并且接口足够通用,以便将来可以轻松集成更复杂的方法。

如何使用优化器#

要使用 torch.optim,您需要构造一个优化器对象,该对象将保存当前状态并根据计算出的梯度更新参数。

构造它#

要构造一个 Optimizer,您需要为其提供一个包含要优化的参数(所有参数都应为 Parameter)或命名参数((str, Parameter) 的元组)的可迭代对象。然后,您可以指定特定于优化器的选项,例如学习率、权重衰减等。

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

命名参数示例

optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)

按参数分组的选项#

Optimizer 还支持按参数分组指定选项。为此,请不要传递 Variable 的可迭代对象,而是传递 dict 的可迭代对象。每个字典将定义一个单独的参数组,并应包含一个 params 键,其中包含属于该组的参数列表。其他键应与优化器接受的关键字参数匹配,并将用作此组的优化选项。

例如,当一个人想要为每个层指定不同的学习率时,这非常有用。

optim.SGD([
    {'params': model.base.parameters(), 'lr': 1e-2},
    {'params': model.classifier.parameters()}
], lr=1e-3, momentum=0.9)

optim.SGD([
    {'params': model.base.named_parameters(), 'lr': 1e-2},
    {'params': model.classifier.named_parameters()}
], lr=1e-3, momentum=0.9)

这意味着 model.base 的参数将使用 1e-2 的学习率,而 model.classifier 的参数将保持默认学习率 1e-3。最后,将对所有参数使用 0.9 的动量。

注意

您仍然可以传递选项作为关键字参数。未在组中覆盖的选项将使用它们作为默认值。当您只想更改一个选项,同时保持参数组之间的所有其他选项一致时,这很有用。

另请考虑以下与参数区分惩罚相关的示例。请记住,parameters() 返回一个可迭代对象,其中包含所有可学习的参数,包括可能需要区分惩罚的偏置和其他参数。为了解决这个问题,可以为每个参数组指定单独的惩罚权重。

bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
others = [p for name, p in self.named_parameters() if 'bias' not in name]

optim.SGD([
    {'params': others},
    {'params': bias_params, 'weight_decay': 0}
], weight_decay=1e-2, lr=1e-2)

这样,偏置项与非偏置项分开,并且为偏置项专门设置了 0weight_decay,以避免对该组进行任何惩罚。

执行优化步骤#

所有优化器都实现了一个 step() 方法,用于更新参数。它有两种用法:

optimizer.step()#

这是大多数优化器支持的简化版本。在计算出梯度后(例如使用 backward()),即可调用该函数。

示例

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

optimizer.step(closure)#

某些优化算法,例如共轭梯度法和 L-BFGS,需要多次重新评估函数,因此您必须传递一个闭包,以便它们可以重新计算模型。闭包应清除梯度,计算损失并返回它。

示例

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

基类#

class torch.optim.Optimizer(params, defaults)[source]#

所有优化器的基类。

警告

需要将参数指定为具有确定性顺序的对象,并且该顺序在不同运行之间是一致的。不满足这些属性的对象示例包括集合以及字典值的迭代器。

参数
  • params (iterable) – 一个 torch.Tensordict 的可迭代对象。指定要优化的 Tensor。

  • defaults (dict[str, Any]) – (dict):一个包含优化选项默认值的字典(当参数组未指定时使用)。

Optimizer.add_param_group

向优化器的 param_groups 添加一个参数组。

Optimizer.load_state_dict

加载优化器状态。

Optimizer.register_load_state_dict_pre_hook

注册一个 load_state_dict 前置钩子,该钩子将在调用 load_state_dict() 之前调用。它应该具有以下签名:。

Optimizer.register_load_state_dict_post_hook

注册一个 load_state_dict 后置钩子,该钩子将在调用 load_state_dict() 之后调用。它应该具有以下签名:。

Optimizer.state_dict

将优化器的状态作为 dict 返回。

Optimizer.register_state_dict_pre_hook

注册一个 state_dict 前置钩子,该钩子将在调用 state_dict() 之前调用。

Optimizer.register_state_dict_post_hook

注册一个 state_dict 后置钩子,该钩子将在调用 state_dict() 之后调用。

Optimizer.step

执行一次优化步骤来更新参数。

Optimizer.register_step_pre_hook

注册一个优化器步骤预钩子,它将在优化器步骤之前被调用。

Optimizer.register_step_post_hook

注册一个优化器步骤后钩子,它将在优化器步骤之后被调用。

Optimizer.zero_grad

重置所有已优化 torch.Tensor 的梯度。

算法#

Adadelta

实现了 Adadelta 算法。

Adafactor

实现了 Adafactor 算法。

Adagrad

实现了 Adagrad 算法。

Adam

实现了 Adam 算法。

AdamW

实现了 AdamW 算法,其中权重衰减不累积到动量或方差中。

SparseAdam

SparseAdam 实现了一个 Adam 算法的掩码版本,适用于稀疏梯度。

Adamax

实现了 Adamax 算法(基于无穷范数的 Adam 变体)。

ASGD

实现了平均随机梯度下降。

LBFGS

实现了 L-BFGS 算法。

Muon

实现了 Muon 算法。

NAdam

实现了 NAdam 算法。

RAdam

实现了 RAdam 算法。

RMSprop

实现了 RMSprop 算法。

Rprop

实现了弹性反向传播算法。

SGD

实现了随机梯度下降(可选带动量)。

我们的许多算法都有各种针对性能、可读性和/或通用性进行优化的实现,因此,如果我们没有指定任何特定的实现,我们会尝试默认使用当前设备上通常最快的实现。

我们有 3 个主要的实现类别:for-loop、foreach(多张量)和 fused。最直接的实现是在具有大型计算块的参数上使用 for-loop。For-loop 通常比我们的 foreach 实现慢,后者将参数合并到一个多张量中,一次性运行大型计算块,从而节省了许多顺序内核调用。我们的一些优化器甚至有更快的 fused 实现,它们将大型计算块融合到一个内核中。我们可以将 foreach 实现视为水平融合,而 fused 实现则在此基础上进行垂直融合。

通常,3 种实现的性能顺序是 fused > foreach > for-loop。因此,在适用时,我们默认使用 foreach 而不是 for-loop。适用意味着 foreach 实现可用,用户没有指定任何特定于实现的 kwargs(例如,fused、foreach、differentiable),并且所有张量都是原生的。请注意,虽然 fused 应该比 foreach 更快,但这些实现较新,我们希望在全面启用之前让它们有更多的时间来完善。我们在下表总结了每种实现的稳定性状态,欢迎您尝试!

下表显示了每种算法可用的和默认的实现。

算法

默认

有 foreach 吗?

有 fused 吗?

Adadelta

foreach

Adafactor

for-loop

Adagrad

foreach

是(仅限 CPU)

Adam

foreach

AdamW

foreach

SparseAdam

for-loop

Adamax

foreach

ASGD

foreach

LBFGS

for-loop

Muon

for-loop

NAdam

foreach

RAdam

foreach

RMSprop

foreach

Rprop

foreach

SGD

foreach

下表显示了 fused 实现的稳定性状态。

算法

CPU

CUDA

MPS

Adadelta

不支持

不支持

不支持

Adafactor

不支持

不支持

不支持

Adagrad

beta

不支持

不支持

Adam

beta

稳定

beta

AdamW

beta

稳定

beta

SparseAdam

不支持

不支持

不支持

Adamax

不支持

不支持

不支持

ASGD

不支持

不支持

不支持

LBFGS

不支持

不支持

不支持

Muon

不支持

不支持

不支持

NAdam

不支持

不支持

不支持

RAdam

不支持

不支持

不支持

RMSprop

不支持

不支持

不支持

Rprop

不支持

不支持

不支持

SGD

beta

beta

beta

如何调整学习率#

torch.optim.lr_scheduler.LRScheduler 提供了几种根据 epoch 数量调整学习率的方法。torch.optim.lr_scheduler.ReduceLROnPlateau 允许根据一些验证度量动态降低学习率。

学习率调度应在优化器更新后应用;例如,您的代码应这样编写:

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

大多数学习率调度器可以连续调用(也称为链式调度器)。其结果是每个调度器按顺序应用于前一个调度器获得学习率。

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()

在文档的许多地方,我们将使用以下模板来引用调度器算法。

>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()

警告

在 PyTorch 1.1.0 之前,学习率调度器应在优化器更新之前调用;1.1.0 以向后不兼容的方式改变了此行为。如果您在更新到 PyTorch 1.1.0 后无法重现结果,请检查您是否在错误的时间调用了 scheduler.step()

lr_scheduler.LRScheduler

在优化过程中调整学习率。

lr_scheduler.LambdaLR

设置初始学习率。

lr_scheduler.MultiplicativeLR

通过指定函数中的因子来乘以每个参数组的学习率。

lr_scheduler.StepLR

每 step_size 个 epoch,将每个参数组的学习率按 gamma 衰减。

lr_scheduler.MultiStepLR

当 epoch 数量达到 milestones 之一时,将每个参数组的学习率按 gamma 衰减一次。

lr_scheduler.ConstantLR

将每个参数组的学习率乘以一个小的常数因子。

lr_scheduler.LinearLR

通过线性改变小的乘法因子来衰减每个参数组的学习率。

lr_scheduler.ExponentialLR

每 epoch,将每个参数组的学习率按 gamma 衰减。

lr_scheduler.PolynomialLR

使用给定 total_iters 中的多项式函数来衰减每个参数组的学习率。

lr_scheduler.CosineAnnealingLR

使用余弦退火调度设置每个参数组的学习率。

lr_scheduler.ChainedScheduler

链接一系列学习率调度器。

lr_scheduler.SequentialLR

包含一系列调度器,这些调度器预计在优化过程中按顺序调用。

lr_scheduler.ReduceLROnPlateau

当某个度量停止改进时,降低学习率。

lr_scheduler.CyclicLR

根据循环学习率策略(CLR)设置每个参数组的学习率。

lr_scheduler.OneCycleLR

根据 1cycle 学习率策略设置每个参数组的学习率。

lr_scheduler.CosineAnnealingWarmRestarts

使用余弦退火调度设置每个参数组的学习率。

如何利用命名参数加载优化器 state_dict#

函数 load_state_dict() 会存储从加载的 state_dict 中可选的 param_names 内容(如果存在)。但是,加载优化器状态的过程不受影响,因为参数的顺序很重要,可以保持兼容性(以防顺序不同)。要利用加载的 state_dict 中的已加载参数名称,需要根据期望的行为实现自定义 register_load_state_dict_pre_hook

这在模型架构发生变化但权重和优化器状态需要保持不变的情况下很有用。以下示例演示了如何实现此自定义。

示例

class OneLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)

    def forward(self, x):
        return self.fc(x)

model = OneLayerModel()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

假设 model 实现了一个专家(MoE),我们想复制它并在两个专家上恢复训练,这两个专家都以与 fc 层相同的方式初始化。对于下面的 model2,我们创建了两个与 fc 相同的层,并通过将 model 的权重和优化器状态加载到 model2fc1fc2 中来恢复训练(并相应地调整它们)。

class TwoLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(3, 4)

    def forward(self, x):
        return (self.fc1(x) + self.fc2(x)) / 2

model2 = TwoLayerModel()
# adapt and load model weights..
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

为了加载 optimizer2 的 state dict,同时加载先前优化器的 state dict,以便 fc1fc2 都将用 fc 优化器状态的副本进行初始化(以便从 fc 继续训练每个层),我们可以使用以下钩子:

def adapt_state_dict_ids(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc1.weight': 'fc.weight',
        'fc1.bias': 'fc.bias',
        'fc2.weight': 'fc.weight',
        'fc2.bias': 'fc.bias'
    }
    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
        id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
        # Copy the state of the corresponding parameter
        if id_in_loaded in state_dict['state']:
            adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

这确保了在模型加载期间将使用适应的 state_dict,其中包含 model2 的层的正确状态。请注意,此代码是专门为此示例设计的(例如,假设只有一个参数组),其他情况可能需要不同的调整。

以下示例显示了如何在模型结构更改时处理加载的 state dict 中缺失的参数。Model_bypass 添加了一个新的 bypass 层,该层在原始 Model1 中不存在。为了恢复训练,使用自定义的 adapt_state_dict_missing_param 钩子来适应优化器的 state_dict,确保现有参数映射正确,而缺失的参数(如示例中初始化的 bypass 层)保持不变。这种方法使得即使模型发生变化,也能平滑地加载和恢复优化器状态。新添加的 bypass 层将从头开始训练。

class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x) + x


model = Model1()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

class Model_bypass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)
        self.bypass = nn.Linear(5, 5, bias=False)
        torch.nn.init.eye_(self.bypass.weight)

    def forward(self, x):
        return self.fc(x) + self.bypass(x)

model2 = Model_bypass()
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

def adapt_state_dict_missing_param(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc.weight': 'fc.weight',
        'fc.bias': 'fc.bias',
        'bypass.weight': None,
    }

    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        if name_in_loaded in state_dict['param_groups'][0]['param_names']:
            index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
            id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

作为第三个示例,该钩子可以用于根据参数的名称加载,而不是根据参数的顺序(默认方法)。

def names_matching(optimizer, state_dict):
    assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
    adapted_state_dict = deepcopy(optimizer.state_dict())
    for g_ind in range(len(state_dict['param_groups'])):
        assert len(state_dict['param_groups'][g_ind]['params']) == len(
            optimizer.state_dict()['param_groups'][g_ind]['params'])

        for k, v in state_dict['param_groups'][g_ind].items():
            if k not in ['params', 'param_names']:
                adapted_state_dict['param_groups'][g_ind][k] = v

        for param_id, param_name in zip(
                optimizer.state_dict()['param_groups'][g_ind]['params'],
                optimizer.state_dict()['param_groups'][g_ind]['param_names']):
            index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
            id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

权重平均(SWA 和 EMA)#

torch.optim.swa_utils.AveragedModel 实现随机权重平均(SWA)和指数移动平均(EMA),torch.optim.swa_utils.SWALR 实现 SWA 学习率调度器,而 torch.optim.swa_utils.update_bn() 是一个在训练结束时用于更新 SWA/EMA 批量归一化统计量的实用函数。

SWA 在 Averaging Weights Leads to Wider Optima and Better Generalization 中被提出。

EMA 是一种广泛用于减少训练时间的技术,通过减少所需的权重更新次数。它是 Polyak averaging 的一个变种,但使用指数权重而不是迭代之间的相等权重。

构造平均模型#

AveragedModel 类用于计算 SWA 或 EMA 模型的权重。

您可以通过运行以下命令创建一个 SWA 平均模型:

>>> averaged_model = AveragedModel(model)

EMA 模型通过将 `multi_avg_fn` 参数指定为以下方式来构造:

>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

衰减是一个介于 0 和 1 之间的参数,它控制平均参数衰减的速度。如果未提供给 torch.optim.swa_utils.get_ema_multi_avg_fn(),则默认值为 0.999。衰减值应接近 1.0,因为较小的值可能导致优化收敛问题。

torch.optim.swa_utils.get_ema_multi_avg_fn() 返回一个函数,该函数将以下 EMA 方程应用于权重:

Wt+1EMA=αWtEMA+(1α)WtmodelW^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t

其中 alpha 是 EMA 衰减。

这里的 model model 可以是任意的 torch.nn.Module 对象。averaged_model 将跟踪 model 参数的运行平均值。要更新这些平均值,您应该在 optimizer.step() 之后使用 update_parameters() 函数。

>>> averaged_model.update_parameters(model)

对于 SWA 和 EMA,此调用通常在优化器 step() 之后不久进行。在 SWA 的情况下,这通常在训练开始的某些步数内跳过。

自定义平均策略#

默认情况下,torch.optim.swa_utils.AveragedModel 计算您提供的参数的运行平均值,但您也可以使用 `avg_fn` 或 `multi_avg_fn` 参数来自定义平均函数。

  • avg_fn 允许定义一个对每个参数元组(平均参数,模型参数)进行操作的函数,并应返回新的平均参数。

  • multi_avg_fn 允许同时定义更高效的操作,这些操作作用于参数列表元组(平均参数列表,模型参数列表),例如使用 torch._foreach* 函数。此函数必须就地更新平均参数。

在以下示例中,ema_model 使用 avg_fn 参数计算指数移动平均。

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

在以下示例中,ema_model 使用更高效的 multi_avg_fn 参数计算指数移动平均。

>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))

SWA 学习率调度#

通常,在 SWA 中,学习率被设置为一个高常量值。SWALR 是一个学习率调度器,它将学习率衰减到一个固定值,然后保持不变。例如,以下代码创建一个调度器,该调度器在每个参数组的 5 个 epoch 内将学习率从初始值线性衰减到 0.05。

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

您还可以通过设置 anneal_strategy="cos" 来使用余弦退火到一个固定值,而不是线性退火。

处理批量归一化#

update_bn() 是一个实用函数,它允许在训练结束时在给定的数据加载器 loader 上计算 SWA 模型的批量归一化统计量。

>>> torch.optim.swa_utils.update_bn(loader, swa_model)

update_bn()swa_model 应用于数据加载器中的每个元素,并计算模型中每个批量归一化层的激活统计量。

警告

update_bn() 假设数据加载器 loader 中的每个批次是张量,或者是一个张量列表/元组,其中第一个元素是网络 swa_model 应该应用的张量。如果您的数据加载器结构不同,您可以通过在数据集的每个元素上进行前向传递(使用 swa_model)来更新 swa_model 的批量归一化统计量。

总而言之:SWA#

在下面的示例中,swa_model 是累积权重平均值的 SWA 模型。我们将模型训练总共 300 个 epoch,并在 epoch 160 时切换到 SWA 学习率调度器并开始收集参数的 SWA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)

总而言之:EMA#

在下面的示例中,ema_model 是 EMA 模型,它以 0.999 的衰减率累积权重的指数衰减平均值。我们将模型训练总共 300 个 epoch,并立即开始收集 EMA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>           ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)

swa_utils.AveragedModel

为随机权重平均(SWA)和指数移动平均(EMA)实现平均模型。

swa_utils.SWALR

将每个参数组的学习率衰减到一个固定值。

torch.optim.swa_utils.get_ema_multi_avg_fn(decay=0.999)[source]#

获取跨多个参数应用指数移动平均(EMA)的函数。

torch.optim.swa_utils.update_bn(loader, model, device=None)[source]#

更新模型中的 BatchNorm running_mean、running_var 缓冲区。

它会遍历 loader 中的数据一次,以估算模型中 BatchNorm 层的激活统计量。

参数
  • loader (torch.utils.data.DataLoader) – 用于计算激活统计量的数据集加载器。每个数据批次都应该是张量,或者是一个列表/元组,其中第一个元素是包含数据的张量。

  • model (torch.nn.Module) – 我们要为其更新 BatchNorm 统计量的模型。

  • device (torch.device, optional) – 如果设置,数据将在传入 model 之前传输到 device

示例

>>> loader, model = ...
>>> torch.optim.swa_utils.update_bn(loader, model)

注意

cite>update_bn 实用函数假定 loader 中的每个数据批次都是张量,或者是张量的列表或元组;在后一种情况下,假定 model.forward() 应该在与数据批次对应的列表或元组的第一个元素上调用。