评价此页

在 PyTorch 中使用不同模型的参数暖启动模型#

创建日期:2020 年 4 月 20 日 | 最后更新:2024 年 8 月 27 日 | 最后验证:2024 年 11 月 5 日

部分加载模型或加载一个不完整的模型是在迁移学习或训练一个复杂的新模型时常见的场景。利用已训练的参数,即使只有一部分可用,也能帮助暖启动训练过程,并有望帮助你的模型比从头开始训练收敛得更快。

简介#

无论你是从一个缺少某些键的、不完整的 state_dict 加载,还是将一个键比你正在加载的模型更多的 state_dict 加载到一个模型中,你都可以在 load_state_dict() 函数中将 strict 参数设置为 False 来忽略不匹配的键。在本指南中,我们将尝试使用不同模型的参数来暖启动一个模型。

设置#

在开始之前,如果尚未安装 torch,我们需要安装它。

pip install torch

步骤#

  1. 导入加载数据所需的所有库

  2. 定义并初始化神经网络 A 和 B

  3. 保存模型 A

  4. 加载到模型 B

1. 导入加载数据所需的库#

在本指南中,我们将使用 torch 及其子模块 torch.nntorch.optim

import torch
import torch.nn as nn
import torch.optim as optim

2. 定义并初始化神经网络 A 和 B#

为方便举例,我们将创建一个用于训练图像的神经网络。如需了解更多信息,请参阅定义神经网络指南。我们将创建两个神经网络,以便将 A 类型的参数加载到 B 类型中。

class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = NetA()

class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netB = NetB()

3. 保存模型 A#

# Specify a path to save to
PATH = "model.pt"

torch.save(netA.state_dict(), PATH)

4. 加载到模型 B#

如果你想将参数从一个层加载到另一个层,但某些键不匹配,只需更改你正在加载的 state_dict 中的参数键名称,使其与你正在加载的模型中的键匹配。

netB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)

可以看到所有键都成功匹配了!

恭喜!你已成功使用 PyTorch 中不同模型的参数暖启动了一个模型。

了解更多#

查看这些其他秘籍以继续您的学习