评价此页

在 PyTorch 中使用来自不同模型的参数进行模型热启动#

创建日期:2020年4月20日 | 最后更新:2024年8月27日 | 最后验证:2024年11月5日

在进行迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见的场景。利用已训练的参数(即使只有一部分可用),将有助于热启动训练过程,并有望帮助模型比从头开始训练收敛得更快。

简介#

无论你是从缺少某些键的部分 state_dict 加载,还是加载包含比目标模型更多键的 state_dict,你都可以通过将 load_state_dict() 函数中的 strict 参数设置为 False 来忽略不匹配的键。在本教程中,我们将尝试使用不同模型的参数来热启动一个模型。

设置#

在开始之前,如果尚未安装 torch,我们需要安装它。

pip install torch

步骤#

  1. 导入加载数据所需的所有库

  2. 定义并初始化神经网络 A 和 B

  3. 保存模型 A

  4. 加载到模型 B 中

1. 导入加载数据所需的库#

在本教程中,我们将使用 torch 及其子模块 torch.nntorch.optim

import torch
import torch.nn as nn
import torch.optim as optim

2. 定义并初始化神经网络 A 和 B#

为了示例起见,我们将创建一个用于图像训练的神经网络。欲了解更多信息,请参阅“定义神经网络”教程。我们将创建两个神经网络,以便演示如何将 A 类型的参数加载到 B 类型模型中。

class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = NetA()

class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netB = NetB()

3. 保存模型 A#

# Specify a path to save to
PATH = "model.pt"

torch.save(netA.state_dict(), PATH)

4. 加载到模型 B 中#

如果你想将参数从一个层加载到另一个层,但某些键不匹配,只需更改你正在加载的 state_dict 中的参数键名称,使其与目标模型中的键相匹配即可。

netB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)

你可以看到所有键都已成功匹配!

恭喜!你已经成功在 PyTorch 中使用来自不同模型的参数对模型进行了热启动。

了解更多#

查看这些其他秘籍以继续您的学习