评价此页

编写自定义数据集、数据加载器和转换#

创建于:2017年6月10日 | 最后更新:2025年3月11日 | 最后验证:2024年11月5日

作者: Sasank Chilamkurthy

解决任何机器学习问题的大量工作都投入到数据准备中。PyTorch 提供了许多工具来简化数据加载,并希望使您的代码更具可读性。在本教程中,我们将学习如何从一个非平凡的数据集中加载和预处理/增强数据。

要运行本教程,请确保已安装以下软件包:

  • scikit-image:用于图像I/O和转换

  • pandas:用于更方便地解析 CSV

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
<contextlib.ExitStack object at 0x7efc581d41c0>

我们将要处理的数据集是面部姿态数据集。这意味着一张脸被标注如下:

../_images/landmarked_face2.png

总的来说,每张脸都标注了 68 个不同的特征点。

注意

请从此处下载数据集,以便图像位于名为 'data/faces/' 的目录中。这个数据集实际上是通过在一些被标记为'face'的 imagenet 图像上应用出色的 dlib 的姿态估计生成的。

数据集附带一个.csv文件,其中的标注如下所示:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

让我们从 CSV 中取一个图像名称及其标注,这里以第 65 行的 person-7.jpg 为例。读取它,将图像名称存储在 img_name 中,并将其标注存储在一个 (L, 2) 的数组 landmarks 中,其中 L 是该行中特征点的数量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

让我们编写一个简单的辅助函数来显示一张图片及其特征点,并用它来展示一个样本。

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()
data loading tutorial

数据集类#

torch.utils.data.Dataset 是一个表示数据集的抽象类。您的自定义数据集应继承 Dataset 并重写以下方法:

  • __len__ 以便 len(dataset) 返回数据集的大小。

  • __getitem__ 以支持索引,这样 dataset[i] 就可以用来获取第 \(i\) 个样本。

让我们为我们的面部特征点数据集创建一个数据集类。我们将在 __init__ 中读取 csv 文件,但将图像的读取留给 __getitem__。这样做是内存高效的,因为所有图像不是一次性存储在内存中,而是按需读取。

我们数据集的样本将是一个字典 {'image': image, 'landmarks': landmarks}。我们的数据集将接受一个可选参数 transform,以便可以对样本应用任何所需的处理。我们将在下一节中看到 transform 的用处。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks], dtype=float).reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

让我们实例化这个类并遍历数据样本。我们将打印前4个样本的大小并显示它们的特征点。

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i, sample in enumerate(face_dataset):
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
Sample #0, Sample #1, Sample #2, Sample #3
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

转换#

从上面我们可以看到一个问题,即样本的大小并不相同。大多数神经网络期望图像具有固定的大小。因此,我们需要编写一些预处理代码。让我们创建三个转换:

  • Rescale:缩放图像

  • RandomCrop:从图像中随机裁剪。这是一种数据增强。

  • ToTensor:将 numpy 图像转换为 torch 图像(我们需要交换轴)。

我们将它们写成可调用类的形式,而不是简单的函数,这样每次调用转换时就不需要传递参数。为此,我们只需要实现 __call__ 方法,如果需要的话,还可以实现 __init__ 方法。然后我们可以像这样使用转换:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

请观察下面这些转换是如何必须同时应用于图像和特征点的。

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

注意

在上面的例子中,RandomCrop 使用了外部库的随机数生成器(在这种情况下是 Numpy 的 np.random.int)。这可能会导致 DataLoader 出现意外行为(请参阅此处)。在实践中,坚持使用 PyTorch 的随机数生成器更安全,例如使用 torch.randint

组合转换#

现在,我们将这些转换应用于一个样本。

假设我们想将图像的短边缩放到 256,然后从中随机裁剪出一个 224x224 的正方形。也就是说,我们想组合 RescaleRandomCrop 两种转换。torchvision.transforms.Compose 是一个简单的可调用类,可以让我们做到这一点。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()
Rescale, RandomCrop, Compose

遍历数据集#

让我们把所有这些组合起来,创建一个带有组合转换的数据集。总结一下,每次对这个数据集进行采样时:

  • 动态地从文件中读取一张图像

  • 对读取的图像应用转换

  • 由于其中一个转换是随机的,因此在采样时会对数据进行增强

我们可以像之前一样使用 for i in range 循环来遍历创建的数据集。

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i, sample in enumerate(transformed_dataset):
    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

然而,通过使用简单的 for 循环来遍历数据,我们失去了很多特性。特别是,我们错过了:

  • 批处理数据

  • 打乱数据

  • 使用 multiprocessing 工作进程并行加载数据。

torch.utils.data.DataLoader 是一个提供所有这些功能的迭代器。下面使用的参数应该很清楚。其中一个值得关注的参数是 collate_fn。您可以使用 collate_fn 来精确指定样本需要如何批处理。然而,对于大多数用例来说,默认的 collate 函数应该可以正常工作。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

# if you are using Windows, uncomment the next line and indent the for loop.
# you might need to go back and change ``num_workers`` to 0.

# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
Batch from dataloader
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后记:torchvision#

在本教程中,我们学习了如何编写和使用数据集、转换和数据加载器。torchvision 包提供了一些常见的数据集和转换。你甚至可能不需要编写自定义类。torchvision 中一个更通用的数据集是 ImageFolder。它假定图像按以下方式组织:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中 ‘ants’、‘bees’ 等是类别标签。同样,也提供了对 PIL.Image 进行操作的通用转换,例如 RandomHorizontalFlipScale。你可以使用这些来编写一个像这样的数据加载器:

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

有关带训练代码的示例,请参阅计算机视觉迁移学习教程

脚本总运行时间: (0 分 1.944 秒)