
Transfer Learning for Computer Vision Tutorial

Created On: Mar 24, 2017 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024

Author: Sasank Chilamkurthy

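In this tutorial, you will learn how to train a convolutional neural network for image classification using transfer learning. You can read more about transfer learning in the cs231n notes.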

Quoting these notes:

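In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images in 1000 categories), and then use the ConvNet either as an initialization or as a fixed feature extractor for the task of interest.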

The two major transfer learning scenarios look as follows:

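  • Finetuning the ConvNet: Instead of random initialization, we initialize the network with a pretrained network, such as one trained on the ImageNet 1000 dataset. The rest of the training proceeds as usual.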

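  • ConvNet as fixed feature extractor: Here, we freeze the weights of the whole network except those of the final fully connected layer. This last fully connected layer is replaced with a new one with random weights, and only this layer is trained.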
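A minimal sketch of how the two scenarios differ in what actually gets trained. To stay self-contained and avoid downloading weights, it uses a tiny hypothetical stand-in network (make_model, with arbitrary layer sizes) instead of ResNet-18, and counts the parameters that would receive gradients in each case.

```python
import torch.nn as nn


def make_model():
    # Hypothetical stand-in: a small "backbone" followed by a final fc layer.
    backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(),
                             nn.AdaptiveAvgPool2d(1), nn.Flatten())
    return nn.Sequential(backbone, nn.Linear(8, 2))


def count_trainable(model):
    # Number of scalar parameters that will receive gradients.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Scenario 1: finetuning, every parameter stays trainable.
finetune = make_model()

# Scenario 2: fixed feature extractor, freeze everything, then attach a new head.
extractor = make_model()
for p in extractor.parameters():
    p.requires_grad = False
extractor[1] = nn.Linear(8, 2)  # a freshly created layer is trainable by default

print(count_trainable(finetune))   # 242: conv (216 + 8) + fc (16 + 2)
print(count_trainable(extractor))  # 18: only the new fc layer
```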

# License: BSD
# Author: Sasank Chilamkurthy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

cudnn.benchmark = True
plt.ion()   # interactive mode

Load Data

We will use the torchvision and torch.utils.data packages to load the data.

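The problem we are going to solve today is to train a model to classify ants and bees. We have about 120 training images each for ants and bees, and 75 validation images for each class. This would usually be far too small a dataset to generalize from if we trained from scratch. Since we are using transfer learning, we should be able to generalize reasonably well.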

This dataset is a very small subset of ImageNet.

Note

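Download the data from here and extract it so that it ends up at data/hymenoptera_data under the current directory, which is the path data_dir points to.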
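datasets.ImageFolder expects one sub-folder per class inside each split, and assigns class indices in sorted folder-name order. The sketch below builds a hypothetical, empty skeleton of the expected data/hymenoptera_data/{train,val}/{ants,bees} layout in a temporary directory and reproduces that index mapping with the standard library only; the helper class_to_idx is illustrative, not torchvision's own code.

```python
import os
import tempfile


def class_to_idx(split_dir):
    # Mirrors ImageFolder's convention: sorted sub-folder names -> 0, 1, ...
    classes = sorted(e.name for e in os.scandir(split_dir) if e.is_dir())
    return {name: i for i, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    # Hypothetical empty skeleton of data/hymenoptera_data/{train,val}/{ants,bees}
    data_dir = os.path.join(root, "data", "hymenoptera_data")
    for split in ("train", "val"):
        for cls in ("bees", "ants"):  # creation order does not matter
            os.makedirs(os.path.join(data_dir, split, cls))
    mapping = class_to_idx(os.path.join(data_dir, "train"))

print(mapping)  # {'ants': 0, 'bees': 1}
```

Because the mapping depends only on folder names, the train and val splits get the same indices as long as they contain the same class folders.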

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# We want to be able to train our model on an `accelerator <https://pytorch.ac.cn/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
Using cuda device
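torch.accelerator only exists in recent PyTorch releases (2.6 and later, to our knowledge), so the device line above raises AttributeError on older versions. A hedged fallback sketch that checks CUDA and Apple MPS explicitly could look like this; the helper name pick_device is hypothetical.

```python
import torch


def pick_device():
    # Prefer the generic accelerator API when this PyTorch version provides it.
    acc = getattr(torch, "accelerator", None)
    if acc is not None and acc.is_available():
        return acc.current_accelerator().type
    # Fallbacks for older releases without torch.accelerator.
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(f"Using {device} device")
x = torch.ones(2, device=device)  # tensors can be created directly on that device
```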

Visualize a few images

Let's visualize a few training images to get a sense of the data augmentations.

def imshow(inp, title=None):
    """Display image for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
['bees', 'bees', 'ants', 'ants']
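transforms.Normalize maps each channel to (x - mean) / std, and imshow undoes this with std * inp + mean after transposing the tensor from C x H x W to H x W x C. A minimal NumPy check of that round trip, using a random H x W x C array as made-up image data:

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

rng = np.random.default_rng(0)
img = rng.random((4, 4, 3))          # fake H x W x C image with values in [0, 1)

normalized = (img - mean) / std      # what Normalize does, per channel
restored = std * normalized + mean   # what imshow does before plotting

print(np.allclose(restored, img))    # True
```

The per-channel arrays broadcast over the last axis, which is why imshow transposes to channels-last before applying them.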

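Training the model

Now, let's write a general function to train a model. Here, we will illustrate:

  • Scheduling the learning rate

  • Saving the best model

In the following, the parameter scheduler is an LR scheduler object from torch.optim.lr_scheduler.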
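In train_model, scheduler.step() is called once per epoch, after the training phase. With the StepLR(step_size=7, gamma=0.1) schedule used later in this tutorial, the learning rate in effect during a given epoch should therefore be base_lr * gamma ** (epoch // step_size). The pure-Python helper step_lr below is a hypothetical restatement of that rule, not PyTorch code:

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    # Learning rate in effect during `epoch` (0-based) under StepLR semantics.
    return base_lr * gamma ** (epoch // step_size)


schedule = [step_lr(0.001, e) for e in range(25)]
# Epochs 0-6 use 1e-3, 7-13 use 1e-4, 14-20 use 1e-5, 21-24 use 1e-6.
for epoch in (0, 6, 7, 14, 21, 24):
    print(epoch, schedule[epoch])
```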

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # save the weights if this is the best validation accuracy so far
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:4f}')

        # load best model weights
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
    return model
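nn.CrossEntropyLoss returns the mean loss over a batch by default, so train_model multiplies it by inputs.size(0) before summing. That way a smaller final batch is weighted correctly when the sum is divided by dataset_sizes[phase]. The made-up batch numbers below illustrate the difference from naively averaging batch means:

```python
# Hypothetical per-batch results: (mean loss of the batch, batch size, correct predictions)
batches = [(0.50, 4, 3), (0.30, 4, 4), (0.90, 2, 1)]  # the last batch is smaller
dataset_size = sum(n for _, n, _ in batches)            # 10 samples

running_loss = sum(loss * n for loss, n, _ in batches)  # like loss.item() * inputs.size(0)
running_corrects = sum(c for _, _, c in batches)

epoch_loss = running_loss / dataset_size                 # (2.0 + 1.2 + 1.8) / 10 = 0.5
epoch_acc = running_corrects / dataset_size              # 8 / 10 = 0.8

naive_loss = sum(loss for loss, _, _ in batches) / len(batches)  # mean of batch means
print(epoch_loss, epoch_acc, naive_loss)
```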

Visualizing the model predictions

A generic function to display the predictions for a few images.

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)
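torch.max(outputs, 1) returns a (values, indices) pair taken along the class dimension; the indices are the predicted class ids that visualize_model looks up in class_names. A small check with made-up logits (the tensor values are assumptions for illustration):

```python
import torch

class_names = ['ants', 'bees']
# Hypothetical logits for a batch of 3 images and 2 classes.
outputs = torch.tensor([[2.0, -1.0],
                        [0.1, 0.3],
                        [-0.5, 1.5]])

values, preds = torch.max(outputs, 1)
print(preds.tolist())                             # [0, 1, 1]
print([class_names[p] for p in preds])            # ['ants', 'bees', 'bees']
print(torch.equal(preds, outputs.argmax(dim=1)))  # True
```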

Finetuning the ConvNet

Load a pretrained model and reset the final fully connected layer.

model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


Train and evaluate

It should take around 15-25 minutes on CPU. On GPU, though, it takes less than a minute.

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
Epoch 0/24
----------
train Loss: 0.5903 Acc: 0.7008
val Loss: 0.2170 Acc: 0.9216

Epoch 1/24
----------
train Loss: 0.4710 Acc: 0.8156
val Loss: 0.2947 Acc: 0.9020

Epoch 2/24
----------
train Loss: 0.4788 Acc: 0.8033
val Loss: 0.2117 Acc: 0.9412

Epoch 3/24
----------
train Loss: 0.4692 Acc: 0.8238
val Loss: 0.1701 Acc: 0.9346

Epoch 4/24
----------
train Loss: 0.4088 Acc: 0.8361
val Loss: 0.3131 Acc: 0.8693

Epoch 5/24
----------
train Loss: 0.4055 Acc: 0.8320
val Loss: 0.2013 Acc: 0.9281

Epoch 6/24
----------
train Loss: 0.3931 Acc: 0.8648
val Loss: 0.4235 Acc: 0.8366

Epoch 7/24
----------
train Loss: 0.3563 Acc: 0.8156
val Loss: 0.3095 Acc: 0.9020

Epoch 8/24
----------
train Loss: 0.2656 Acc: 0.9098
val Loss: 0.3170 Acc: 0.9020

Epoch 9/24
----------
train Loss: 0.3295 Acc: 0.8566
val Loss: 0.2846 Acc: 0.8954

Epoch 10/24
----------
train Loss: 0.2747 Acc: 0.8770
val Loss: 0.2425 Acc: 0.9216

Epoch 11/24
----------
train Loss: 0.2622 Acc: 0.8607
val Loss: 0.2274 Acc: 0.9085

Epoch 12/24
----------
train Loss: 0.2942 Acc: 0.8689
val Loss: 0.2726 Acc: 0.8889

Epoch 13/24
----------
train Loss: 0.2743 Acc: 0.8770
val Loss: 0.1998 Acc: 0.9085

Epoch 14/24
----------
train Loss: 0.2348 Acc: 0.8852
val Loss: 0.1951 Acc: 0.9020

Epoch 15/24
----------
train Loss: 0.2979 Acc: 0.8689
val Loss: 0.2160 Acc: 0.8889

Epoch 16/24
----------
train Loss: 0.2774 Acc: 0.8730
val Loss: 0.2210 Acc: 0.9020

Epoch 17/24
----------
train Loss: 0.2253 Acc: 0.9180
val Loss: 0.2172 Acc: 0.9020

Epoch 18/24
----------
train Loss: 0.2654 Acc: 0.8648
val Loss: 0.2134 Acc: 0.9216

Epoch 19/24
----------
train Loss: 0.3400 Acc: 0.8607
val Loss: 0.2012 Acc: 0.9150

Epoch 20/24
----------
train Loss: 0.2938 Acc: 0.8811
val Loss: 0.2054 Acc: 0.9216

Epoch 21/24
----------
train Loss: 0.2120 Acc: 0.9016
val Loss: 0.2008 Acc: 0.9281

Epoch 22/24
----------
train Loss: 0.2631 Acc: 0.8975
val Loss: 0.2023 Acc: 0.9085

Epoch 23/24
----------
train Loss: 0.3325 Acc: 0.8197
val Loss: 0.1949 Acc: 0.9281

Epoch 24/24
----------
train Loss: 0.2967 Acc: 0.8770
val Loss: 0.2071 Acc: 0.8954

Training complete in 0m 36s
Best val Acc: 0.941176
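The last epoch above reports val Acc 0.8954, while the summary reports a best val Acc of 0.941176: train_model reloads the checkpoint saved at the best validation epoch before returning. The sketch below isolates that state_dict save and restore pattern; the tiny model and the in-place weight change standing in for "later epochs" are assumptions.

```python
import os
from tempfile import TemporaryDirectory

import torch
import torch.nn as nn

model = nn.Linear(3, 2)  # hypothetical tiny model

with TemporaryDirectory() as tempdir:
    best_path = os.path.join(tempdir, 'best_model_params.pt')
    torch.save(model.state_dict(), best_path)   # checkpoint at the "best" epoch
    best_weight = model.weight.detach().clone()

    with torch.no_grad():
        model.weight.add_(1.0)                  # later epochs change the weights

    # Restore the best checkpoint, as train_model does before returning.
    model.load_state_dict(torch.load(best_path, weights_only=True))

print(torch.equal(model.weight.detach(), best_weight))  # True
```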
visualize_model(model_ft)
predicted: bees, predicted: bees, predicted: ants, predicted: bees, predicted: ants, predicted: ants

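ConvNet as fixed feature extractor

Here, we need to freeze all of the network except the final layer. We need to set requires_grad = False to freeze the parameters so that their gradients are not computed in backward().

You can read more about this in the documentation here.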
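A quick check of what requires_grad = False does during backward(): frozen parameters keep .grad equal to None, while the newly created head receives gradients. The tiny network standing in for ResNet-18 and the random inputs are assumptions made to keep the sketch fast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny network standing in for ResNet-18: backbone + final fc.
model = nn.Sequential(nn.Linear(10, 6), nn.ReLU(), nn.Linear(6, 2))

for param in model.parameters():
    param.requires_grad = False        # freeze everything
model[2] = nn.Linear(6, 2)             # new head: requires_grad=True by default

loss = nn.CrossEntropyLoss()(model(torch.randn(4, 10)), torch.tensor([0, 1, 0, 1]))
loss.backward()

print(model[0].weight.grad)            # None: no gradient computed for the frozen layer
print(model[2].weight.grad.shape)      # torch.Size([2, 6])
```

Skipping these gradients is what makes this scenario cheaper per step, although the forward pass still runs through the whole network.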

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

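Train and evaluate

On CPU, this will take about half the time of the previous scenario. This is expected, since gradients do not need to be computed for most of the network. However, the forward pass still has to be computed.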

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
Epoch 0/24
----------
train Loss: 0.6331 Acc: 0.6434
val Loss: 0.2510 Acc: 0.9216

Epoch 1/24
----------
train Loss: 0.6771 Acc: 0.6967
val Loss: 0.2352 Acc: 0.9020

Epoch 2/24
----------
train Loss: 0.4537 Acc: 0.8197
val Loss: 0.2055 Acc: 0.9412

Epoch 3/24
----------
train Loss: 0.4050 Acc: 0.8033
val Loss: 0.1749 Acc: 0.9281

Epoch 4/24
----------
train Loss: 0.6083 Acc: 0.7746
val Loss: 0.1692 Acc: 0.9412

Epoch 5/24
----------
train Loss: 0.4104 Acc: 0.8361
val Loss: 0.2291 Acc: 0.9281

Epoch 6/24
----------
train Loss: 0.3597 Acc: 0.8320
val Loss: 0.1661 Acc: 0.9477

Epoch 7/24
----------
train Loss: 0.3475 Acc: 0.8648
val Loss: 0.1690 Acc: 0.9477

Epoch 8/24
----------
train Loss: 0.3004 Acc: 0.8893
val Loss: 0.2063 Acc: 0.9477

Epoch 9/24
----------
train Loss: 0.3520 Acc: 0.8648
val Loss: 0.1741 Acc: 0.9412

Epoch 10/24
----------
train Loss: 0.3576 Acc: 0.8402
val Loss: 0.1614 Acc: 0.9542

Epoch 11/24
----------
train Loss: 0.3188 Acc: 0.8525
val Loss: 0.1865 Acc: 0.9542

Epoch 12/24
----------
train Loss: 0.3279 Acc: 0.8648
val Loss: 0.1853 Acc: 0.9542

Epoch 13/24
----------
train Loss: 0.2537 Acc: 0.8893
val Loss: 0.2184 Acc: 0.9216

Epoch 14/24
----------
train Loss: 0.3280 Acc: 0.8607
val Loss: 0.1771 Acc: 0.9608

Epoch 15/24
----------
train Loss: 0.3321 Acc: 0.8648
val Loss: 0.1889 Acc: 0.9412

Epoch 16/24
----------
train Loss: 0.2804 Acc: 0.8730
val Loss: 0.1907 Acc: 0.9477

Epoch 17/24
----------
train Loss: 0.3353 Acc: 0.8443
val Loss: 0.2022 Acc: 0.9412

Epoch 18/24
----------
train Loss: 0.2714 Acc: 0.8852
val Loss: 0.1664 Acc: 0.9477

Epoch 19/24
----------
train Loss: 0.3275 Acc: 0.8156
val Loss: 0.1865 Acc: 0.9542

Epoch 20/24
----------
train Loss: 0.3696 Acc: 0.8443
val Loss: 0.1755 Acc: 0.9608

Epoch 21/24
----------
train Loss: 0.2963 Acc: 0.8607
val Loss: 0.1840 Acc: 0.9477

Epoch 22/24
----------
train Loss: 0.3033 Acc: 0.8730
val Loss: 0.1836 Acc: 0.9477

Epoch 23/24
----------
train Loss: 0.3573 Acc: 0.8361
val Loss: 0.1755 Acc: 0.9608

Epoch 24/24
----------
train Loss: 0.3503 Acc: 0.8648
val Loss: 0.1618 Acc: 0.9477

Training complete in 0m 28s
Best val Acc: 0.960784
visualize_model(model_conv)

plt.ioff()
plt.show()
predicted: ants, predicted: bees, predicted: ants, predicted: bees, predicted: bees, predicted: ants

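Inference on custom images

Use the trained model to make predictions on custom images, and visualize the predicted class labels along with the images.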

def visualize_model_predictions(model, img_path):
    was_training = model.training
    model.eval()

    # Convert to RGB so that grayscale or RGBA images also yield 3 channels
    img = Image.open(img_path).convert('RGB')
    img = data_transforms['val'](img)
    img = img.unsqueeze(0)  # add the batch dimension the model expects
    img = img.to(device)

    with torch.no_grad():
        outputs = model(img)
        _, preds = torch.max(outputs, 1)

        ax = plt.subplot(2, 2, 1)
        ax.axis('off')
        ax.set_title(f'Predicted: {class_names[preds[0]]}')
        imshow(img.cpu().data[0])

        model.train(mode=was_training)


visualize_model_predictions(
    model_conv,
    img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)

plt.ioff()
plt.show()
Predicted: bees

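Further Learning

If you would like to learn more about the applications of transfer learning, check out our Quantized Transfer Learning for Computer Vision Tutorial.

Total running time of the script: (1 minute 6.492 seconds)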