注意
跳转到最后 下载完整的示例代码。
TorchVision 对象检测微调教程#
创建于:2023 年 12 月 14 日 | 最后更新:2025 年 9 月 5 日 | 最后验证:2024 年 11 月 5 日
在本教程中,我们将对预训练的 Mask R-CNN 模型进行微调,并在 宾夕法尼亚大学行人检测和分割数据库 (Penn-Fudan Database for Pedestrian Detection and Segmentation) 上进行训练。该数据库包含 170 张图像,其中有 345 个行人实例,我们将使用它来演示如何利用 torchvision 的新特性在自定义数据集上训练对象检测和实例分割模型。
注意
此教程仅适用于 torchvision 版本 >=0.16 或 nightly 版本。如果您使用的是 torchvision<=0.15,请参考 此教程。
定义数据集#
用于训练对象检测、实例分割和关键点检测的参考脚本可以轻松支持添加新的自定义数据集。数据集应继承自标准的 torch.utils.data.Dataset 类,并实现 __len__ 和 __getitem__ 方法。
我们唯一的要求是,数据集的 __getitem__ 方法应该返回一个元组
image: 形状为
[3, H, W]的torchvision.tv_tensors.Image,一个纯张量,或者一个大小为(H, W)的 PIL 图像target: 一个包含以下字段的字典
boxes, 形状为[N, 4]的torchvision.tv_tensors.BoundingBoxes:N个边界框的坐标,格式为[x0, y0, x1, y1],范围从0到W和0到Hlabels, 形状为[N]的整数torch.Tensor:每个边界框的标签。0始终代表背景类。image_id, int: 图像标识符。它应该在数据集的所有图像中是唯一的,并在评估期间使用。area, 形状为[N]的浮点数torch.Tensor:边界框的面积。在 COCO 指标评估期间使用,用于区分小、中、大框的指标得分。iscrowd, 形状为[N]的 uint8torch.Tensor:iscrowd=True的实例在评估期间将被忽略。(可选)
masks, 形状为[N, H, W]的torchvision.tv_tensors.Mask:每个对象的分割掩码。
如果您的数据集符合上述要求,那么它将适用于参考脚本中的训练和评估代码。评估代码将使用来自 pycocotools 的脚本,可以使用 pip install pycocotools 进行安装。
注意
对于 Windows,请从 gautamchitnis 安装 pycocotools,命令如下:
pip install git+https://github.com/gautamchitnis/cocoapi.git@cocodataset-master#subdirectory=PythonAPI
关于 labels 的一个说明。模型将类 0 视为背景。如果您的数据集不包含背景类,则不应在 labels 中包含 0。例如,假设您只有两个类:猫 和 狗,您可以将 1 (而不是 0) 定义为代表 猫,将 2 定义为代表 狗。因此,例如,如果一张图像同时包含这两个类,您的 labels 张量应该看起来像 [1, 2]。
此外,如果您想在训练期间使用纵横比分组(以便每个批次只包含具有相似纵横比的图像),则建议也实现一个 get_height_and_width 方法,该方法返回图像的高度和宽度。如果未提供此方法,我们将通过 __getitem__ 查询数据集的所有元素,这会加载图像到内存中,比提供自定义方法慢。
为 PennFudan 编写自定义数据集#
让我们为 PennFudan 数据集编写一个数据集。首先,下载数据集并解压 zip 文件。
wget https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip -P data
cd data && unzip PennFudanPed.zip
我们有以下文件夹结构
PennFudanPed/
PedMasks/
FudanPed00001_mask.png
FudanPed00002_mask.png
FudanPed00003_mask.png
FudanPed00004_mask.png
...
PNGImages/
FudanPed00001.png
FudanPed00002.png
FudanPed00003.png
FudanPed00004.png
这是一对图像和分割掩码的示例
import matplotlib.pyplot as plt
from torchvision.io import read_image
image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
mask = read_image("data/PennFudanPed/PedMasks/FudanPed00046_mask.png")
plt.figure(figsize=(16, 8))
plt.subplot(121)
plt.title("Image")
plt.imshow(image.permute(1, 2, 0))
plt.subplot(122)
plt.title("Mask")
plt.imshow(mask.permute(1, 2, 0))

<matplotlib.image.AxesImage object at 0x7f8955babee0>
因此,每张图像都有一个对应的分割掩码,其中每种颜色代表一个不同的实例。让我们为该数据集编写一个 torch.utils.data.Dataset 类。在下面的代码中,我们将图像、边界框和掩码封装到 torchvision.tv_tensors.TVTensor 类中,以便能够为给定的对象检测和分割任务应用 torchvision 内置的变换(新的 Transforms API)。具体来说,图像张量将被 torchvision.tv_tensors.Image 封装,边界框将被 torchvision.tv_tensors.BoundingBoxes 封装,掩码将被 torchvision.tv_tensors.Mask 封装。由于 torchvision.tv_tensors.TVTensor 是 torch.Tensor 的子类,因此封装的对象也是张量,并继承了普通的 torch.Tensor API。有关 torchvision tv_tensors 的更多信息,请参阅 此文档。
import os
import torch
from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
class PennFudanDataset(torch.utils.data.Dataset):
def __init__(self, root, transforms):
self.root = root
self.transforms = transforms
# load all image files, sorting them to
# ensure that they are aligned
self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))
def __getitem__(self, idx):
# load images and masks
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
img = read_image(img_path)
mask = read_image(mask_path)
# instances are encoded as different colors
obj_ids = torch.unique(mask)
# first id is the background, so remove it
obj_ids = obj_ids[1:]
num_objs = len(obj_ids)
# split the color-encoded mask into a set
# of binary masks
masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)
# get bounding box coordinates for each mask
boxes = masks_to_boxes(masks)
# there is only one class
labels = torch.ones((num_objs,), dtype=torch.int64)
image_id = idx
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# suppose all instances are not crowd
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
# Wrap sample and targets into torchvision tv_tensors:
img = tv_tensors.Image(img)
target = {}
target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
target["masks"] = tv_tensors.Mask(masks)
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
def __len__(self):
return len(self.imgs)
关于数据集就这么多。现在让我们定义一个模型,它可以对这个数据集进行预测。
定义模型#
在本教程中,我们将使用 Mask R-CNN,它是基于 Faster R-CNN 的。Faster R-CNN 是一个同时预测图像中潜在对象的边界框和类别分数的模型。
Mask R-CNN 在 Faster R-CNN 的基础上增加了一个额外的分支,该分支还可以预测每个实例的分割掩码。
在两种常见情况下,我们可能需要修改 TorchVision Model Zoo 中提供的模型之一。第一种情况是我们想从预训练模型开始,只微调最后一层。另一种情况是我们想用不同的骨干网络替换模型的骨干网络(例如,为了更快的预测)。
在接下来的章节中,我们将分别介绍如何实现这两种方法。
1 - 从预训练模型进行微调#
假设您想从在 COCO 上预训练的模型开始,并想为您的特定类别进行微调。以下是一种可能的方法:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2 # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
0%| | 0.00/160M [00:00<?, ?B/s]
26%|██▌ | 41.0M/160M [00:00<00:00, 429MB/s]
52%|█████▏ | 83.1M/160M [00:00<00:00, 436MB/s]
78%|███████▊ | 125M/160M [00:00<00:00, 419MB/s]
100%|██████████| 160M/160M [00:00<00:00, 426MB/s]
2 - 修改模型以添加不同的骨干网络#
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
# ``FasterRCNN`` needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280
# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios
anchor_generator = AnchorGenerator(
sizes=((32, 64, 128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),)
)
# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# ``OrderedDict[Tensor]``, and in ``featmap_names`` you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
featmap_names=['0'],
output_size=7,
sampling_ratio=2
)
# put the pieces together inside a Faster-RCNN model
model = FasterRCNN(
backbone,
num_classes=2,
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler
)
Downloading: "https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/mobilenet_v2-7ebf99e0.pth
0%| | 0.00/13.6M [00:00<?, ?B/s]
100%|██████████| 13.6M/13.6M [00:00<00:00, 282MB/s]
用于 PennFudan 数据集的对象检测和实例分割模型#
在我们的情况下,由于数据集很小,我们想从预训练模型进行微调,因此我们将遵循方法 1。
在这里,我们还想计算实例分割掩码,因此我们将使用 Mask R-CNN。
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
def get_model_instance_segmentation(num_classes):
# load an instance segmentation model pre-trained on COCO
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# now get the number of input features for the mask classifier
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
# and replace the mask predictor with a new one
model.roi_heads.mask_predictor = MaskRCNNPredictor(
in_features_mask,
hidden_layer,
num_classes
)
return model
这样,model 就准备好在您的自定义数据集上进行训练和评估了。
整合所有内容#
在 references/detection/ 目录下,我们提供了一些辅助函数,用于简化检测模型的训练和评估。在这里,我们将使用 references/detection/engine.py 和 references/detection/utils.py。只需将 references/detection 下的所有文件下载到您的文件夹中即可在此处使用。在 Linux 上,如果您有 wget,您可以使用以下命令下载它们:
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/utils.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_utils.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py")
0
自 0.15.0 版本起,torchvision 提供了 新的 Transforms API,可以轻松地为对象检测和分割任务编写数据增强管道。
让我们编写一些用于数据增强/变换的辅助函数。
from torchvision.transforms import v2 as T
def get_transform(train):
transforms = []
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
transforms.append(T.ToDtype(torch.float, scale=True))
transforms.append(T.ToPureTensor())
return T.Compose(transforms)
测试 forward() 方法(可选)#
在遍历数据集之前,最好了解模型在训练和推理期间对样本数据的期望。
import utils
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
dataset = PennFudanDataset('data/PennFudanPed', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=utils.collate_fn
)
# For Training
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images, targets) # Returns losses and detections
print(output)
# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x) # Returns predictions
print(predictions[0])
{'loss_classifier': tensor(0.0396, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.0487, grad_fn=<DivBackward0>), 'loss_objectness': tensor(0.0076, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(0.0048, grad_fn=<DivBackward0>)}
{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)}
我们希望能够在 加速器(如 CUDA、MPS、MTIA 或 XPU)上训练我们的模型。现在让我们编写执行训练和验证的主函数。
from engine import train_one_epoch, evaluate
# train on the accelerator or on the CPU, if an accelerator is not available
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device('cpu')
# our dataset has two classes only - background and person
num_classes = 2
# use our dataset and defined transformations
dataset = PennFudanDataset('data/PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('data/PennFudanPed', get_transform(train=False))
# split the dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])
# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=utils.collate_fn
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=1,
shuffle=False,
collate_fn=utils.collate_fn
)
# get the model using our helper function
model = get_model_instance_segmentation(num_classes)
# move model to the right device
model.to(device)
# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
params,
lr=0.005,
momentum=0.9,
weight_decay=0.0005
)
# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=3,
gamma=0.1
)
# let's train it just for 2 epochs
num_epochs = 2
for epoch in range(num_epochs):
# train for one epoch, printing every 10 iterations
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
# update the learning rate
lr_scheduler.step()
# evaluate on the test dataset
evaluate(model, data_loader_test, device=device)
print("That's it!")
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
0%| | 0.00/170M [00:00<?, ?B/s]
25%|██▍ | 42.0M/170M [00:00<00:00, 439MB/s]
49%|████▉ | 84.0M/170M [00:00<00:00, 435MB/s]
75%|███████▍ | 127M/170M [00:00<00:00, 442MB/s]
100%|██████████| 170M/170M [00:00<00:00, 444MB/s]
/var/lib/workspace/intermediate_source/engine.py:30: FutureWarning:
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
Epoch: [0] [ 0/60] eta: 0:00:39 lr: 0.000090 loss: 3.9297 (3.9297) loss_classifier: 0.8942 (0.8942) loss_box_reg: 0.2116 (0.2116) loss_mask: 2.7965 (2.7965) loss_objectness: 0.0242 (0.0242) loss_rpn_box_reg: 0.0032 (0.0032) time: 0.6557 data: 0.0122 max mem: 1753
Epoch: [0] [10/60] eta: 0:00:12 lr: 0.000936 loss: 1.4103 (1.9439) loss_classifier: 0.4974 (0.5230) loss_box_reg: 0.2421 (0.2619) loss_mask: 0.7706 (1.1278) loss_objectness: 0.0242 (0.0272) loss_rpn_box_reg: 0.0032 (0.0039) time: 0.2470 data: 0.0153 max mem: 2759
Epoch: [0] [20/60] eta: 0:00:09 lr: 0.001783 loss: 0.8459 (1.3453) loss_classifier: 0.2110 (0.3548) loss_box_reg: 0.2549 (0.2574) loss_mask: 0.3053 (0.7064) loss_objectness: 0.0116 (0.0229) loss_rpn_box_reg: 0.0030 (0.0039) time: 0.2098 data: 0.0157 max mem: 2759
Epoch: [0] [30/60] eta: 0:00:06 lr: 0.002629 loss: 0.5936 (1.0881) loss_classifier: 0.1046 (0.2662) loss_box_reg: 0.2762 (0.2534) loss_mask: 0.1992 (0.5460) loss_objectness: 0.0042 (0.0178) loss_rpn_box_reg: 0.0041 (0.0047) time: 0.2117 data: 0.0158 max mem: 2759
Epoch: [0] [40/60] eta: 0:00:04 lr: 0.003476 loss: 0.4908 (0.9438) loss_classifier: 0.0671 (0.2164) loss_box_reg: 0.2265 (0.2521) loss_mask: 0.1678 (0.4537) loss_objectness: 0.0051 (0.0161) loss_rpn_box_reg: 0.0056 (0.0054) time: 0.2133 data: 0.0168 max mem: 2759
Epoch: [0] [50/60] eta: 0:00:02 lr: 0.004323 loss: 0.4616 (0.8495) loss_classifier: 0.0546 (0.1871) loss_box_reg: 0.1976 (0.2361) loss_mask: 0.1711 (0.4063) loss_objectness: 0.0051 (0.0139) loss_rpn_box_reg: 0.0066 (0.0061) time: 0.2162 data: 0.0167 max mem: 2863
Epoch: [0] [59/60] eta: 0:00:00 lr: 0.005000 loss: 0.3605 (0.7698) loss_classifier: 0.0450 (0.1650) loss_box_reg: 0.1146 (0.2153) loss_mask: 0.1818 (0.3713) loss_objectness: 0.0024 (0.0121) loss_rpn_box_reg: 0.0066 (0.0061) time: 0.2098 data: 0.0150 max mem: 2863
Epoch: [0] Total time: 0:00:13 (0.2188 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:04 model_time: 0.0796 (0.0796) evaluator_time: 0.0029 (0.0029) time: 0.0939 data: 0.0108 max mem: 2863
Test: [49/50] eta: 0:00:00 model_time: 0.0416 (0.0582) evaluator_time: 0.0036 (0.0058) time: 0.0644 data: 0.0102 max mem: 2863
Test: Total time: 0:00:03 (0.0744 s / it)
Averaged stats: model_time: 0.0416 (0.0582) evaluator_time: 0.0036 (0.0058)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.684
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.978
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.829
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.432
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.700
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.306
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.764
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.764
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.657
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.770
IoU metric: segm
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.669
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.978
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.880
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.409
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.690
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.282
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.724
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.724
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.643
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.729
Epoch: [1] [ 0/60] eta: 0:00:12 lr: 0.005000 loss: 0.3263 (0.3263) loss_classifier: 0.0701 (0.0701) loss_box_reg: 0.1182 (0.1182) loss_mask: 0.1341 (0.1341) loss_objectness: 0.0015 (0.0015) loss_rpn_box_reg: 0.0024 (0.0024) time: 0.2075 data: 0.0233 max mem: 2863
Epoch: [1] [10/60] eta: 0:00:11 lr: 0.005000 loss: 0.3105 (0.3021) loss_classifier: 0.0435 (0.0443) loss_box_reg: 0.0975 (0.1014) loss_mask: 0.1341 (0.1485) loss_objectness: 0.0013 (0.0020) loss_rpn_box_reg: 0.0036 (0.0059) time: 0.2218 data: 0.0173 max mem: 3233
Epoch: [1] [20/60] eta: 0:00:08 lr: 0.005000 loss: 0.3105 (0.3143) loss_classifier: 0.0435 (0.0441) loss_box_reg: 0.0975 (0.1026) loss_mask: 0.1612 (0.1593) loss_objectness: 0.0013 (0.0023) loss_rpn_box_reg: 0.0036 (0.0061) time: 0.2152 data: 0.0171 max mem: 3233
Epoch: [1] [30/60] eta: 0:00:06 lr: 0.005000 loss: 0.2894 (0.2990) loss_classifier: 0.0423 (0.0424) loss_box_reg: 0.0826 (0.0972) loss_mask: 0.1506 (0.1519) loss_objectness: 0.0008 (0.0020) loss_rpn_box_reg: 0.0042 (0.0056) time: 0.2079 data: 0.0165 max mem: 3233
Epoch: [1] [40/60] eta: 0:00:04 lr: 0.005000 loss: 0.2337 (0.2834) loss_classifier: 0.0383 (0.0409) loss_box_reg: 0.0672 (0.0904) loss_mask: 0.1217 (0.1447) loss_objectness: 0.0007 (0.0019) loss_rpn_box_reg: 0.0035 (0.0055) time: 0.2063 data: 0.0149 max mem: 3233
Epoch: [1] [50/60] eta: 0:00:02 lr: 0.005000 loss: 0.2160 (0.2717) loss_classifier: 0.0297 (0.0389) loss_box_reg: 0.0580 (0.0850) loss_mask: 0.1183 (0.1406) loss_objectness: 0.0007 (0.0018) loss_rpn_box_reg: 0.0036 (0.0054) time: 0.2046 data: 0.0142 max mem: 3233
Epoch: [1] [59/60] eta: 0:00:00 lr: 0.005000 loss: 0.2160 (0.2682) loss_classifier: 0.0288 (0.0384) loss_box_reg: 0.0578 (0.0817) loss_mask: 0.1196 (0.1406) loss_objectness: 0.0009 (0.0022) loss_rpn_box_reg: 0.0036 (0.0052) time: 0.2044 data: 0.0147 max mem: 3233
Epoch: [1] Total time: 0:00:12 (0.2084 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:02 model_time: 0.0398 (0.0398) evaluator_time: 0.0023 (0.0023) time: 0.0532 data: 0.0107 max mem: 3233
Test: [49/50] eta: 0:00:00 model_time: 0.0400 (0.0402) evaluator_time: 0.0025 (0.0039) time: 0.0548 data: 0.0102 max mem: 3233
Test: Total time: 0:00:02 (0.0544 s / it)
Averaged stats: model_time: 0.0400 (0.0402) evaluator_time: 0.0025 (0.0039)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.819
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.990
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.945
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.487
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.827
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.362
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.854
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.854
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.857
IoU metric: segm
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.757
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.990
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.935
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.359
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.773
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.331
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.794
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.794
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.657
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.802
That's it!
因此,经过一个 epoch 的训练后,我们获得了超过 50 的 COCO 风格 mAP,以及 65 的掩码 mAP。
但预测结果是什么样的呢?让我们取数据集中的一张图像并进行验证。
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
eval_transform = get_transform(train=False)
model.eval()
with torch.no_grad():
x = eval_transform(image)
# convert RGBA -> RGB and move to device
x = x[:3, ...].to(device)
predictions = model([x, ])
pred = predictions[0]
image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
image = image[:3, ...]
pred_labels = [f"pedestrian: {score:.3f}" for label, score in zip(pred["labels"], pred["scores"])]
pred_boxes = pred["boxes"].long()
output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="red")
masks = (pred["masks"] > 0.7).squeeze(1)
output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")
plt.figure(figsize=(12, 12))
plt.imshow(output_image.permute(1, 2, 0))

<matplotlib.image.AxesImage object at 0x7f8953ee74c0>
结果看起来不错!
总结#
在本教程中,您学习了如何为自定义数据集上的对象检测模型创建自己的训练管道。为此,您编写了一个 torch.utils.data.Dataset 类,该类返回图像和真实边界框及分割掩码。您还利用了在 COCO train2017 上预训练的 Mask R-CNN 模型,以便在新数据集上执行迁移学习。
有关更完整的示例,包括多机/多 GPU 训练,请查看 torchvision 存储库中的 references/detection/train.py。
脚本总运行时间: (0 分 47.300 秒)