
Transfer Learning for Computer Vision Tutorial#

Created On: Mar 24, 2017 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024

Author: Sasank Chilamkurthy

In this tutorial, you will learn how to train a convolutional neural network for image classification using transfer learning. You can read more about transfer learning in the cs231n notes.

Quoting these notes:

In practice, very few people train an entire convolutional network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use that ConvNet either as an initialization or as a fixed feature extractor for the task of interest.

These two major transfer learning scenarios look as follows (a minimal code sketch of both follows this list):

  • Finetuning the ConvNet: Instead of random initialization, we initialize the network with a pretrained network, like one that has been trained on the ImageNet 1000 dataset. The rest of the training looks as usual.

  • ConvNet as fixed feature extractor: Here, we freeze the weights of the whole network except the final fully connected layer. This last fully connected layer is replaced with a new one with random weights, and only this layer is trained.
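
A minimal sketch of the two setups, using the same torchvision model that this tutorial uses later (both are shown in full below):

from torchvision import models
import torch.nn as nn

# 1. Finetuning: initialize from pretrained weights, then train all parameters.
model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, 2)  # replace the classification head

# 2. Fixed feature extractor: freeze the backbone, train only the new head.
model = models.resnet18(weights='IMAGENET1K_V1')
for p in model.parameters():
    p.requires_grad = False  # no gradients are computed for frozen weights
model.fc = nn.Linear(model.fc.in_features, 2)  # a new layer defaults to requires_grad=True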

# License: BSD
# Author: Sasank Chilamkurthy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

cudnn.benchmark = True
plt.ion()   # interactive mode

Load Data#

We will use the torchvision and torch.utils.data packages for loading the data.

The problem we're going to solve today is to train a model to classify **ants** and **bees**. We have about 120 training images each for ants and bees, and 75 validation images for each class. Usually this is a very small dataset to generalize upon if trained from scratch. Since we are using transfer learning, we should be able to generalize reasonably well.

This dataset is a very small subset of ImageNet.

Note

Download the data from here and extract it to the current directory.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # normalize with the ImageNet channel means/stds that the pretrained weights expect
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
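
datasets.ImageFolder infers the class labels from the names of the subdirectories, so the extracted data is expected to be laid out as follows (a sketch; individual file names are arbitrary):

data/hymenoptera_data/
    train/
        ants/    <- ~120 training images per class
        bees/
    val/
        ants/    <- 75 validation images per class
        bees/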

# We want to be able to train our model on an accelerator
# (https://pytorch.ac.cn/docs/stable/torch.html#accelerators)
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
Using cuda device
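
torch.accelerator is a fairly recent PyTorch API; if your build predates it (an assumption about your environment), an equivalent CUDA-or-CPU selection is:

device = "cuda" if torch.cuda.is_available() else "cpu"  # fallback for older PyTorch builds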

Visualize a few images#

Let's visualize a few training images so as to understand the data augmentations.

def imshow(inp, title=None):
    """Display image for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean  # undo the normalization applied in data_transforms
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
['bees', 'bees', 'bees', 'ants']

Training the model#

Now, let's write a general function to train a model. Here, we will illustrate:

  • Scheduling the learning rate

  • Saving the best model

In the following, the parameter scheduler is an LR scheduler object from torch.optim.lr_scheduler; a concrete example of its effect is sketched below.
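
The scheduler actually used later in this tutorial is StepLR(optimizer, step_size=7, gamma=0.1) with an initial lr of 0.001; each scheduler.step() call at the end of a training epoch advances the decay schedule:

# StepLR(optimizer, step_size=7, gamma=0.1) with lr=0.001 yields:
#   epochs  0-6   -> lr = 0.001
#   epochs  7-13  -> lr = 0.0001
#   epochs 14-20  -> lr = 0.00001, and so on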

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # deep copy the model
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:4f}')

        # load best model weights
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
    return model

Visualizing the model predictions#

Generic function to display predictions for a few images:

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

Finetuning the ConvNet#

Load a pretrained model and reset the final fully connected layer.

model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

100%|██████████| 44.7M/44.7M [00:00<00:00, 425MB/s]
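
As an optional sanity check (a small sketch, not part of the original code), we can confirm that in this scenario every parameter is trainable:

num_trainable = sum(p.numel() for p in model_ft.parameters() if p.requires_grad)
print(f'{num_trainable:,} trainable parameters')  # all ~11.2M parameters of ResNet-18 with a 2-class head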

Train and evaluate#

It should take around 15-25 min on CPU. On GPU, it takes less than a minute.

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
Epoch 0/24
----------
train Loss: 0.5152 Acc: 0.7213
val Loss: 0.1994 Acc: 0.9085

Epoch 1/24
----------
train Loss: 0.7914 Acc: 0.6967
val Loss: 0.3712 Acc: 0.8824

Epoch 2/24
----------
train Loss: 0.5874 Acc: 0.8115
val Loss: 0.5923 Acc: 0.8235

Epoch 3/24
----------
train Loss: 0.5997 Acc: 0.7295
val Loss: 0.2508 Acc: 0.9085

Epoch 4/24
----------
train Loss: 0.6454 Acc: 0.7705
val Loss: 0.4392 Acc: 0.8431

Epoch 5/24
----------
train Loss: 0.3868 Acc: 0.8197
val Loss: 0.2327 Acc: 0.9020

Epoch 6/24
----------
train Loss: 0.4876 Acc: 0.8238
val Loss: 0.2429 Acc: 0.8889

Epoch 7/24
----------
train Loss: 0.2886 Acc: 0.8893
val Loss: 0.2705 Acc: 0.8758

Epoch 8/24
----------
train Loss: 0.2296 Acc: 0.8975
val Loss: 0.2613 Acc: 0.8954

Epoch 9/24
----------
train Loss: 0.3337 Acc: 0.8730
val Loss: 0.2036 Acc: 0.9281

Epoch 10/24
----------
train Loss: 0.3590 Acc: 0.8361
val Loss: 0.2013 Acc: 0.9412

Epoch 11/24
----------
train Loss: 0.2668 Acc: 0.8975
val Loss: 0.1857 Acc: 0.9412

Epoch 12/24
----------
train Loss: 0.2810 Acc: 0.8934
val Loss: 0.1941 Acc: 0.9281

Epoch 13/24
----------
train Loss: 0.2550 Acc: 0.9139
val Loss: 0.2215 Acc: 0.9020

Epoch 14/24
----------
train Loss: 0.2253 Acc: 0.8893
val Loss: 0.1996 Acc: 0.9281

Epoch 15/24
----------
train Loss: 0.2971 Acc: 0.8443
val Loss: 0.1992 Acc: 0.9346

Epoch 16/24
----------
train Loss: 0.2413 Acc: 0.8893
val Loss: 0.2070 Acc: 0.9281

Epoch 17/24
----------
train Loss: 0.2938 Acc: 0.8607
val Loss: 0.2347 Acc: 0.8758

Epoch 18/24
----------
train Loss: 0.3521 Acc: 0.8361
val Loss: 0.2288 Acc: 0.9412

Epoch 19/24
----------
train Loss: 0.3839 Acc: 0.8279
val Loss: 0.2003 Acc: 0.9346

Epoch 20/24
----------
train Loss: 0.3309 Acc: 0.8607
val Loss: 0.1932 Acc: 0.9281

Epoch 21/24
----------
train Loss: 0.2534 Acc: 0.9057
val Loss: 0.1914 Acc: 0.9346

Epoch 22/24
----------
train Loss: 0.2178 Acc: 0.9180
val Loss: 0.2271 Acc: 0.9150

Epoch 23/24
----------
train Loss: 0.3389 Acc: 0.8484
val Loss: 0.1903 Acc: 0.9216

Epoch 24/24
----------
train Loss: 0.2731 Acc: 0.8689
val Loss: 0.2015 Acc: 0.9216

Training complete in 0m 34s
Best val Acc: 0.941176
visualize_model(model_ft)
predicted: ants, predicted: bees, predicted: ants, predicted: ants, predicted: bees, predicted: ants

ConvNet as fixed feature extractor#

Here, we need to freeze the whole network except the final layer. We set requires_grad = False on the frozen parameters so that no gradients are computed for them in backward().

You can read more about this in the documentation here.

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
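
The same sanity check as before (again a sketch) now counts only the weights and bias of the new final layer, since everything else is frozen:

num_trainable = sum(p.numel() for p in model_conv.parameters() if p.requires_grad)
print(f'{num_trainable:,} trainable parameters')  # just the fc layer: 512*2 + 2 = 1026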

Train and evaluate#

On CPU this will take about half the time of the previous scenario. This is expected, since gradients don't need to be computed for most of the network. The forward pass, however, does still need to be computed.

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
Epoch 0/24
----------
train Loss: 0.5645 Acc: 0.7172
val Loss: 0.3859 Acc: 0.8235

Epoch 1/24
----------
train Loss: 0.5759 Acc: 0.7500
val Loss: 0.2375 Acc: 0.9020

Epoch 2/24
----------
train Loss: 0.4996 Acc: 0.7746
val Loss: 0.1874 Acc: 0.9346

Epoch 3/24
----------
train Loss: 0.4633 Acc: 0.7951
val Loss: 0.1806 Acc: 0.9542

Epoch 4/24
----------
train Loss: 0.5829 Acc: 0.7705
val Loss: 0.1843 Acc: 0.9281

Epoch 5/24
----------
train Loss: 0.6413 Acc: 0.7500
val Loss: 0.2893 Acc: 0.8758

Epoch 6/24
----------
train Loss: 0.4166 Acc: 0.8443
val Loss: 0.1735 Acc: 0.9542

Epoch 7/24
----------
train Loss: 0.3167 Acc: 0.8443
val Loss: 0.1765 Acc: 0.9412

Epoch 8/24
----------
train Loss: 0.2941 Acc: 0.8525
val Loss: 0.1759 Acc: 0.9542

Epoch 9/24
----------
train Loss: 0.3958 Acc: 0.8238
val Loss: 0.1921 Acc: 0.9346

Epoch 10/24
----------
train Loss: 0.4026 Acc: 0.8279
val Loss: 0.1733 Acc: 0.9281

Epoch 11/24
----------
train Loss: 0.2880 Acc: 0.8975
val Loss: 0.1600 Acc: 0.9608

Epoch 12/24
----------
train Loss: 0.4121 Acc: 0.8279
val Loss: 0.1708 Acc: 0.9542

Epoch 13/24
----------
train Loss: 0.3426 Acc: 0.8566
val Loss: 0.1870 Acc: 0.9346

Epoch 14/24
----------
train Loss: 0.3509 Acc: 0.8402
val Loss: 0.1712 Acc: 0.9477

Epoch 15/24
----------
train Loss: 0.3749 Acc: 0.8566
val Loss: 0.2046 Acc: 0.9281

Epoch 16/24
----------
train Loss: 0.3702 Acc: 0.8197
val Loss: 0.1683 Acc: 0.9477

Epoch 17/24
----------
train Loss: 0.2380 Acc: 0.9016
val Loss: 0.1633 Acc: 0.9477

Epoch 18/24
----------
train Loss: 0.3156 Acc: 0.8566
val Loss: 0.1773 Acc: 0.9542

Epoch 19/24
----------
train Loss: 0.2901 Acc: 0.8689
val Loss: 0.2012 Acc: 0.9281

Epoch 20/24
----------
train Loss: 0.2951 Acc: 0.8975
val Loss: 0.1735 Acc: 0.9477

Epoch 21/24
----------
train Loss: 0.3652 Acc: 0.8361
val Loss: 0.1859 Acc: 0.9477

Epoch 22/24
----------
train Loss: 0.3250 Acc: 0.8566
val Loss: 0.1757 Acc: 0.9412

Epoch 23/24
----------
train Loss: 0.2952 Acc: 0.8607
val Loss: 0.1869 Acc: 0.9477

Epoch 24/24
----------
train Loss: 0.3594 Acc: 0.8484
val Loss: 0.1850 Acc: 0.9346

Training complete in 0m 28s
Best val Acc: 0.960784
visualize_model(model_conv)

plt.ioff()
plt.show()
predicted: bees, predicted: bees, predicted: ants, predicted: ants, predicted: ants, predicted: bees

Inference on custom images#

Use the trained model to make predictions on custom images and visualize the predicted class labels along with the images.

def visualize_model_predictions(model, img_path):
    was_training = model.training
    model.eval()

    img = Image.open(img_path)
    img = data_transforms['val'](img)
    img = img.unsqueeze(0)
    img = img.to(device)

    with torch.no_grad():
        outputs = model(img)
        _, preds = torch.max(outputs, 1)

        ax = plt.subplot(2,2,1)
        ax.axis('off')
        ax.set_title(f'Predicted: {class_names[preds[0]]}')
        imshow(img.cpu().data[0])

        model.train(mode=was_training)
visualize_model_predictions(
    model_conv,
    img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)

plt.ioff()
plt.show()
Predicted: bees

Further Learning#

If you would like to learn more about the applications of transfer learning, check out our Quantized Transfer Learning for Computer Vision Tutorial.

Total running time of the script: (1 minutes 4.623 seconds)