评价此页

TorchVision 目标检测微调教程#

创建于:2023年12月14日 | 最后更新:2024年6月11日 | 最后验证:2024年11月5日

在本教程中,我们将在 Penn-Fudan 行人检测与分割数据库上微调一个预训练的 Mask R-CNN 模型。该数据库包含170张图像和345个行人实例,我们将用它来演示如何使用 torchvision 中的新功能在自定义数据集上训练目标检测和实例分割模型。

注意

本教程仅适用于 torchvision 版本 >=0.16 或 nightly 版本。如果您使用的是 torchvision<=0.15,请改为参考此教程

定义数据集#

用于训练目标检测、实例分割和人体关键点检测的参考脚本可以轻松支持添加新的自定义数据集。数据集应继承自标准的 torch.utils.data.Dataset 类,并实现 __len____getitem__

我们唯一的要求是数据集的 __getitem__ 方法应该返回一个元组:

  • image: 一个形状为 [3, H, W]torchvision.tv_tensors.Image,一个纯张量,或者一个大小为 (H, W) 的 PIL 图像

  • target: 一个包含以下字段的字典

    • boxes,形状为 [N, 4]torchvision.tv_tensors.BoundingBoxesN 个边界框的坐标,格式为 [x0, y0, x1, y1],范围从 0W0H

    • labels,形状为 [N] 的整数 torch.Tensor:每个边界框的标签。 0 始终代表背景类。

    • image_id,int 类型:一个图像标识符。它在数据集中的所有图像之间应该是唯一的,并在评估期间使用

    • area,形状为 [N] 的浮点型 torch.Tensor:边界框的面积。这在 COCO 指标评估中使用,用于区分小、中、大框的指标分数。

    • iscrowd,形状为 [N] 的 uint8 型 torch.Tensor:在评估期间,iscrowd=True 的实例将被忽略。

    • (可选)masks,形状为 [N, H, W]torchvision.tv_tensors.Mask:每个对象的分割掩码

如果您的数据集符合上述要求,那么它将适用于参考脚本中的训练和评估代码。评估代码将使用 pycocotools 中的脚本,可以通过 pip install pycocotools 进行安装。

注意

对于 Windows,请使用以下命令从 gautamchitnis 安装 pycocotools

pip install git+https://github.com/gautamchitnis/cocoapi.git@cocodataset-master#subdirectory=PythonAPI

关于 labels 的一点说明。模型将类别 0 视为背景。如果您的数据集不包含背景类,那么您的 labels 中不应该有 0。例如,假设您只有两个类别,,您可以定义 1(而不是 0)来表示2 来表示。因此,举例来说,如果其中一张图片同时包含这两个类别,您的 labels 张量应该看起来像 [1, 2]

此外,如果您想在训练期间使用宽高比分组(以便每个批次只包含具有相似宽高比的图像),那么建议也实现一个 get_height_and_width 方法,该方法返回图像的高度和宽度。如果未提供此方法,我们将通过 __getitem__ 查询数据集的所有元素,这会将图像加载到内存中,比提供自定义方法要慢。

为 PennFudan 编写自定义数据集#

让我们为 PennFudan 数据集编写一个数据集。首先,让我们下载数据集并解压 zip 文件

wget https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip -P data
cd data && unzip PennFudanPed.zip

我们有以下文件夹结构

PennFudanPed/
  PedMasks/
    FudanPed00001_mask.png
    FudanPed00002_mask.png
    FudanPed00003_mask.png
    FudanPed00004_mask.png
    ...
  PNGImages/
    FudanPed00001.png
    FudanPed00002.png
    FudanPed00003.png
    FudanPed00004.png

这是一对图像和分割掩码的示例

import matplotlib.pyplot as plt
from torchvision.io import read_image


image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
mask = read_image("data/PennFudanPed/PedMasks/FudanPed00046_mask.png")

plt.figure(figsize=(16, 8))
plt.subplot(121)
plt.title("Image")
plt.imshow(image.permute(1, 2, 0))
plt.subplot(122)
plt.title("Mask")
plt.imshow(mask.permute(1, 2, 0))
Image, Mask
<matplotlib.image.AxesImage object at 0x7f4d5ffea050>

因此,每张图像都有一个对应的分割掩码,其中每种颜色对应一个不同的实例。让我们为这个数据集编写一个 torch.utils.data.Dataset 类。在下面的代码中,我们将图像、边界框和掩码包装到 torchvision.tv_tensors.TVTensor 类中,这样我们就可以为给定的目标检测和分割任务应用 torchvision 内置的变换(新的 Transforms API)。具体来说,图像张量将被 torchvision.tv_tensors.Image 包装,边界框被包装到 torchvision.tv_tensors.BoundingBoxes 中,掩码被包装到 torchvision.tv_tensors.Mask 中。由于 torchvision.tv_tensors.TVTensortorch.Tensor 的子类,被包装的对象也是张量,并继承了普通的 torch.Tensor API。有关 torchvision tv_tensors 的更多信息,请参阅此文档

import os
import torch

from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = read_image(img_path)
        mask = read_image(mask_path)
        # instances are encoded as different colors
        obj_ids = torch.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        num_objs = len(obj_ids)

        # split the color-encoded mask into a set
        # of binary masks
        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

        # get bounding box coordinates for each mask
        boxes = masks_to_boxes(masks)

        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)

        image_id = idx
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        # Wrap sample and targets into torchvision tv_tensors:
        img = tv_tensors.Image(img)

        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
        target["masks"] = tv_tensors.Mask(masks)
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

数据集部分就到此为止。现在让我们定义一个可以在此数据集上进行预测的模型。

定义你的模型#

在本教程中,我们将使用 Mask R-CNN,它基于 Faster R-CNN。Faster R-CNN 是一个模型,可以预测图像中潜在对象的边界框和类别分数。

../_static/img/tv_tutorial/tv_image03.png

Mask R-CNN 在 Faster R-CNN 中增加了一个额外的分支,该分支还为每个实例预测分割掩码。

../_static/img/tv_tutorial/tv_image04.png

在 TorchVision Model Zoo 中,有两种常见情况可能需要修改现有模型。第一种是当我们想从一个预训练模型开始,只微调最后一层。另一种是当我们想用不同的主干网络替换模型的骨干部分(例如,为了更快的预测)。

让我们在接下来的部分中看看如何实现这两种方式。

1 - 从预训练模型微调#

假设您想从一个在 COCO 上预训练的模型开始,并希望针对您的特定类别进行微调。以下是一种可能的方法:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth

  0%|          | 0.00/160M [00:00<?, ?B/s]
 11%|█         | 17.9M/160M [00:00<00:00, 187MB/s]
 24%|██▍       | 39.0M/160M [00:00<00:00, 207MB/s]
 37%|███▋      | 58.8M/160M [00:00<00:00, 200MB/s]
 49%|████▉     | 77.9M/160M [00:00<00:00, 187MB/s]
 60%|██████    | 95.9M/160M [00:00<00:00, 146MB/s]
 69%|██████▉   | 111M/160M [00:00<00:00, 132MB/s]
 84%|████████▎ | 134M/160M [00:00<00:00, 159MB/s]
 94%|█████████▍| 150M/160M [00:01<00:00, 145MB/s]
100%|██████████| 160M/160M [00:01<00:00, 157MB/s]

2 - 修改模型以添加不同的主干网络#

import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
# ``FasterRCNN`` needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280

# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128, 256, 512),),
    aspect_ratios=((0.5, 1.0, 2.0),)
)

# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# ``OrderedDict[Tensor]``, and in ``featmap_names`` you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
    featmap_names=['0'],
    output_size=7,
    sampling_ratio=2
)

# put the pieces together inside a Faster-RCNN model
model = FasterRCNN(
    backbone,
    num_classes=2,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler
)
Downloading: "https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/mobilenet_v2-7ebf99e0.pth

  0%|          | 0.00/13.6M [00:00<?, ?B/s]
100%|██████████| 13.6M/13.6M [00:00<00:00, 302MB/s]

用于 PennFudan 数据集的目标检测和实例分割模型#

在我们的案例中,我们希望从一个预训练的模型进行微调,因为我们的数据集非常小,所以我们将遵循第一种方法。

这里我们还想计算实例分割掩码,所以我们将使用 Mask R-CNN。

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model

就是这样,这将使 model 准备好在您的自定义数据集上进行训练和评估。

整合所有内容#

references/detection/ 中,我们有许多辅助函数来简化检测模型的训练和评估。这里,我们将使用 references/detection/engine.pyreferences/detection/utils.py。只需将 references/detection 下的所有文件下载到您的文件夹中,并在此处使用它们。在 Linux 上,如果您有 wget,可以使用以下命令下载它们

os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/utils.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_utils.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py")
0

自 v0.15.0 起,torchvision 提供了新的 Transforms API,可以轻松地为目标检测和分割任务编写数据增强流程。

让我们编写一些用于数据增强/变换的辅助函数

from torchvision.transforms import v2 as T


def get_transform(train):
    transforms = []
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.ToDtype(torch.float, scale=True))
    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)

测试 forward() 方法(可选)#

在迭代数据集之前,最好先看看模型在训练和推理时对样本数据的期望输入是什么。

import utils

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
dataset = PennFudanDataset('data/PennFudanPed', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=utils.collate_fn
)

# For Training
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images, targets)  # Returns losses and detections
print(output)

# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)  # Returns predictions
print(predictions[0])
{'loss_classifier': tensor(0.0548, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.0224, grad_fn=<DivBackward0>), 'loss_objectness': tensor(0.0088, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(0.0036, grad_fn=<DivBackward0>)}
{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)}

现在让我们编写执行训练和验证的主函数

from engine import train_one_epoch, evaluate

# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and person
num_classes = 2
# use our dataset and defined transformations
dataset = PennFudanDataset('data/PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('data/PennFudanPed', get_transform(train=False))

# split the dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=utils.collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    collate_fn=utils.collate_fn
)

# get the model using our helper function
model = get_model_instance_segmentation(num_classes)

# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

# let's train it just for 2 epochs
num_epochs = 2

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

print("That's it!")
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

  0%|          | 0.00/170M [00:00<?, ?B/s]
 16%|█▌        | 26.8M/170M [00:00<00:00, 280MB/s]
 32%|███▏      | 53.5M/170M [00:00<00:00, 241MB/s]
 45%|████▌     | 76.9M/170M [00:00<00:00, 226MB/s]
 63%|██████▎   | 106M/170M [00:00<00:00, 255MB/s]
 77%|███████▋  | 131M/170M [00:00<00:00, 257MB/s]
 92%|█████████▏| 156M/170M [00:00<00:00, 218MB/s]
100%|██████████| 170M/170M [00:00<00:00, 237MB/s]
/var/lib/workspace/intermediate_source/engine.py:30: FutureWarning:

`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.

Epoch: [0]  [ 0/60]  eta: 0:00:41  lr: 0.000090  loss: 3.5105 (3.5105)  loss_classifier: 0.7271 (0.7271)  loss_box_reg: 0.4294 (0.4294)  loss_mask: 2.3374 (2.3374)  loss_objectness: 0.0123 (0.0123)  loss_rpn_box_reg: 0.0043 (0.0043)  time: 0.6965  data: 0.0178  max mem: 2339
Epoch: [0]  [10/60]  eta: 0:00:12  lr: 0.000936  loss: 1.6036 (2.0725)  loss_classifier: 0.4683 (0.4551)  loss_box_reg: 0.2096 (0.2884)  loss_mask: 0.8901 (1.3081)  loss_objectness: 0.0138 (0.0165)  loss_rpn_box_reg: 0.0043 (0.0043)  time: 0.2576  data: 0.0154  max mem: 2625
Epoch: [0]  [20/60]  eta: 0:00:09  lr: 0.001783  loss: 0.8523 (1.4413)  loss_classifier: 0.2064 (0.3259)  loss_box_reg: 0.2175 (0.2835)  loss_mask: 0.3499 (0.8078)  loss_objectness: 0.0195 (0.0192)  loss_rpn_box_reg: 0.0040 (0.0050)  time: 0.2174  data: 0.0167  max mem: 2671
Epoch: [0]  [30/60]  eta: 0:00:06  lr: 0.002629  loss: 0.6438 (1.1489)  loss_classifier: 0.1037 (0.2442)  loss_box_reg: 0.2175 (0.2563)  loss_mask: 0.2364 (0.6239)  loss_objectness: 0.0157 (0.0183)  loss_rpn_box_reg: 0.0049 (0.0061)  time: 0.2182  data: 0.0172  max mem: 2671
Epoch: [0]  [40/60]  eta: 0:00:04  lr: 0.003476  loss: 0.4668 (0.9805)  loss_classifier: 0.0503 (0.2006)  loss_box_reg: 0.1692 (0.2431)  loss_mask: 0.2048 (0.5153)  loss_objectness: 0.0103 (0.0155)  loss_rpn_box_reg: 0.0037 (0.0059)  time: 0.2100  data: 0.0150  max mem: 2671
Epoch: [0]  [50/60]  eta: 0:00:02  lr: 0.004323  loss: 0.4090 (0.8642)  loss_classifier: 0.0419 (0.1699)  loss_box_reg: 0.1485 (0.2242)  loss_mask: 0.1784 (0.4503)  loss_objectness: 0.0023 (0.0136)  loss_rpn_box_reg: 0.0068 (0.0062)  time: 0.2064  data: 0.0148  max mem: 2671
Epoch: [0]  [59/60]  eta: 0:00:00  lr: 0.005000  loss: 0.3670 (0.7873)  loss_classifier: 0.0371 (0.1515)  loss_box_reg: 0.1286 (0.2088)  loss_mask: 0.1784 (0.4087)  loss_objectness: 0.0021 (0.0119)  loss_rpn_box_reg: 0.0071 (0.0064)  time: 0.2054  data: 0.0155  max mem: 2827
Epoch: [0] Total time: 0:00:13 (0.2194 s / it)
creating index...
index created!
Test:  [ 0/50]  eta: 0:00:05  model_time: 0.0814 (0.0814)  evaluator_time: 0.0062 (0.0062)  time: 0.1040  data: 0.0158  max mem: 2827
Test:  [49/50]  eta: 0:00:00  model_time: 0.0391 (0.0557)  evaluator_time: 0.0030 (0.0051)  time: 0.0624  data: 0.0103  max mem: 2827
Test: Total time: 0:00:03 (0.0713 s / it)
Averaged stats: model_time: 0.0391 (0.0557)  evaluator_time: 0.0030 (0.0051)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.717
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.992
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.929
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.608
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.722
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.311
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.766
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.766
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.773
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.765
IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.694
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.992
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.855
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.464
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.708
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.307
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.727
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.728
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.636
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.737
Epoch: [1]  [ 0/60]  eta: 0:00:10  lr: 0.005000  loss: 0.2681 (0.2681)  loss_classifier: 0.0252 (0.0252)  loss_box_reg: 0.0840 (0.0840)  loss_mask: 0.1534 (0.1534)  loss_objectness: 0.0016 (0.0016)  loss_rpn_box_reg: 0.0039 (0.0039)  time: 0.1753  data: 0.0202  max mem: 2827
Epoch: [1]  [10/60]  eta: 0:00:10  lr: 0.005000  loss: 0.2788 (0.3368)  loss_classifier: 0.0360 (0.0426)  loss_box_reg: 0.0709 (0.1065)  loss_mask: 0.1613 (0.1768)  loss_objectness: 0.0016 (0.0036)  loss_rpn_box_reg: 0.0061 (0.0073)  time: 0.2067  data: 0.0176  max mem: 2827
Epoch: [1]  [20/60]  eta: 0:00:08  lr: 0.005000  loss: 0.2898 (0.3246)  loss_classifier: 0.0390 (0.0426)  loss_box_reg: 0.0828 (0.1067)  loss_mask: 0.1527 (0.1656)  loss_objectness: 0.0013 (0.0026)  loss_rpn_box_reg: 0.0068 (0.0071)  time: 0.2113  data: 0.0165  max mem: 2827
Epoch: [1]  [30/60]  eta: 0:00:06  lr: 0.005000  loss: 0.3002 (0.3168)  loss_classifier: 0.0414 (0.0412)  loss_box_reg: 0.1022 (0.1059)  loss_mask: 0.1479 (0.1608)  loss_objectness: 0.0008 (0.0027)  loss_rpn_box_reg: 0.0052 (0.0063)  time: 0.2079  data: 0.0148  max mem: 2827
Epoch: [1]  [40/60]  eta: 0:00:04  lr: 0.005000  loss: 0.2758 (0.2997)  loss_classifier: 0.0343 (0.0387)  loss_box_reg: 0.0838 (0.0994)  loss_mask: 0.1405 (0.1537)  loss_objectness: 0.0007 (0.0025)  loss_rpn_box_reg: 0.0027 (0.0055)  time: 0.2013  data: 0.0146  max mem: 2827
Epoch: [1]  [50/60]  eta: 0:00:02  lr: 0.005000  loss: 0.2522 (0.2937)  loss_classifier: 0.0327 (0.0388)  loss_box_reg: 0.0676 (0.0945)  loss_mask: 0.1326 (0.1527)  loss_objectness: 0.0007 (0.0022)  loss_rpn_box_reg: 0.0026 (0.0053)  time: 0.2071  data: 0.0156  max mem: 2938
Epoch: [1]  [59/60]  eta: 0:00:00  lr: 0.005000  loss: 0.2579 (0.2899)  loss_classifier: 0.0374 (0.0382)  loss_box_reg: 0.0696 (0.0927)  loss_mask: 0.1416 (0.1515)  loss_objectness: 0.0009 (0.0021)  loss_rpn_box_reg: 0.0046 (0.0053)  time: 0.2099  data: 0.0157  max mem: 2938
Epoch: [1] Total time: 0:00:12 (0.2075 s / it)
creating index...
index created!
Test:  [ 0/50]  eta: 0:00:03  model_time: 0.0472 (0.0472)  evaluator_time: 0.0041 (0.0041)  time: 0.0676  data: 0.0158  max mem: 2938
Test:  [49/50]  eta: 0:00:00  model_time: 0.0380 (0.0404)  evaluator_time: 0.0026 (0.0037)  time: 0.0551  data: 0.0103  max mem: 2938
Test: Total time: 0:00:02 (0.0545 s / it)
Averaged stats: model_time: 0.0380 (0.0404)  evaluator_time: 0.0026 (0.0037)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.793
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.993
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.963
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.658
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.801
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.346
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.825
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.825
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.764
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.831
IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.766
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.993
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.949
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.557
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.777
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.332
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.798
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.798
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.727
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.804
That's it!

因此,经过一个周期的训练,我们得到了一个大于 50 的 COCO 风格 mAP,以及一个 65 的掩码 mAP。

但是预测结果看起来如何呢?让我们从数据集中取一张图片来验证一下

import matplotlib.pyplot as plt

from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks


image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
eval_transform = get_transform(train=False)

model.eval()
with torch.no_grad():
    x = eval_transform(image)
    # convert RGBA -> RGB and move to device
    x = x[:3, ...].to(device)
    predictions = model([x, ])
    pred = predictions[0]


image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
image = image[:3, ...]
pred_labels = [f"pedestrian: {score:.3f}" for label, score in zip(pred["labels"], pred["scores"])]
pred_boxes = pred["boxes"].long()
output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="red")

masks = (pred["masks"] > 0.7).squeeze(1)
output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")


plt.figure(figsize=(12, 12))
plt.imshow(output_image.permute(1, 2, 0))
torchvision tutorial
<matplotlib.image.AxesImage object at 0x7f4db35f1360>

结果看起来不错!

总结#

在本教程中,您学习了如何为自定义数据集上的目标检测模型创建自己的训练流程。为此,您编写了一个 torch.utils.data.Dataset 类,它返回图像以及真实的边界框和分割掩码。您还利用了一个在 COCO train2017 上预训练的 Mask R-CNN 模型,以便在这个新数据集上执行迁移学习。

要查看一个更完整的示例,包括多机/多 GPU 训练,请查看 torchvision 仓库中的 references/detection/train.py

脚本总运行时间: (0 分 46.926 秒)