注意
跳转至页面底部 下载完整示例代码。
在 PyTorch 中使用来自不同模型的参数进行模型热启动#
创建日期:2020年4月20日 | 最后更新:2024年8月27日 | 最后验证:2024年11月5日
在进行迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见的场景。利用已训练的参数(即使只有一部分可用),将有助于热启动训练过程,并有望帮助模型比从头开始训练收敛得更快。
简介#
无论你是从缺少某些键的部分 state_dict 加载,还是加载包含比目标模型更多键的 state_dict,你都可以通过将 load_state_dict() 函数中的 strict 参数设置为 False 来忽略不匹配的键。在本教程中,我们将尝试使用不同模型的参数来热启动一个模型。
设置#
在开始之前,如果尚未安装 torch,我们需要安装它。
pip install torch
步骤#
导入加载数据所需的所有库
定义并初始化神经网络 A 和 B
保存模型 A
加载到模型 B 中
1. 导入加载数据所需的库#
在本教程中,我们将使用 torch 及其子模块 torch.nn 和 torch.optim。
import torch
import torch.nn as nn
import torch.optim as optim
2. 定义并初始化神经网络 A 和 B#
为了示例起见,我们将创建一个用于图像训练的神经网络。欲了解更多信息,请参阅“定义神经网络”教程。我们将创建两个神经网络,以便演示如何将 A 类型的参数加载到 B 类型模型中。
class NetA(nn.Module):
def __init__(self):
super(NetA, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
netA = NetA()
class NetB(nn.Module):
def __init__(self):
super(NetB, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
netB = NetB()
3. 保存模型 A#
# Specify a path to save to
PATH = "model.pt"
torch.save(netA.state_dict(), PATH)
4. 加载到模型 B 中#
如果你想将参数从一个层加载到另一个层,但某些键不匹配,只需更改你正在加载的 state_dict 中的参数键名称,使其与目标模型中的键相匹配即可。
netB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)
你可以看到所有键都已成功匹配!
恭喜!你已经成功在 PyTorch 中使用来自不同模型的参数对模型进行了热启动。
了解更多#
查看这些其他秘籍以继续您的学习