快捷方式

如何使用 CutMix 和 MixUp

注意

Colab 上尝试,或 前往文末 下载完整的示例代码。

CutMixMixUp 是流行的增强策略,可以提高分类准确性。

这些转换与 Torchvision 的其他转换略有不同,因为它们期望输入的是样本的批次,而不是单个图像。在此示例中,我们将解释如何使用它们:在 DataLoader 之后,或作为 collate 函数的一部分。

import torch
from torchvision.datasets import FakeData
from torchvision.transforms import v2


NUM_CLASSES = 100

预处理流程

我们将使用一个简单但典型的图像分类流程

preproc = v2.Compose([
    v2.PILToTensor(),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # typically from ImageNet
])

dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)

img, label = dataset[0]
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.float32, img.shape = torch.Size([3, 224, 224]), label = 67

需要注意的一点是,CutMix 和 MixUp 都不属于此预处理流程。我们将在定义 DataLoader 后稍后添加它们。仅作为回顾,如果我们不使用 CutMix 或 MixUp,DataLoader 和训练循环会是这样的

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    print(labels.dtype)
    # <rest of the training loop here>
    break
images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
torch.int64

在哪里使用 MixUp 和 CutMix

在 DataLoader 之后

现在让我们添加 CutMix 和 MixUp。最简单的方法是在 DataLoader 之后立即执行此操作:DataLoader 已经为我们批处理了图像和标签,而这正是这些转换所期望的输入

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

cutmix = v2.CutMix(num_classes=NUM_CLASSES)
mixup = v2.MixUp(num_classes=NUM_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

for images, labels in dataloader:
    print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
    images, labels = cutmix_or_mixup(images, labels)
    print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")

    # <rest of the training loop here>
    break
Before CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
After CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])

请注意标签也已转换:我们从形状为 (batch_size,) 的批处理标签转换为形状为 (batch_size, num_classes) 的张量。转换后的标签仍可按原样传递给像 torch.nn.functional.cross_entropy() 这样的损失函数。

作为 collate 函数的一部分

在 DataLoader 之后传递转换是最简单的使用 CutMix 和 MixUp 的方法,但一个缺点是它没有利用 DataLoader 的多进程。为此,我们可以将这些转换作为 collate 函数的一部分传递(请参阅 PyTorch 文档 了解有关 collate 的更多信息)。

from torch.utils.data import default_collate


def collate_fn(batch):
    return cutmix_or_mixup(*default_collate(batch))


dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)

for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
    # <rest of the training loop here>
    break
images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])

非标准输入格式

到目前为止,我们使用了一个典型的样本结构,我们将 (images, labels) 作为输入。MixUp 和 CutMix 默认情况下可以与大多数常见的样本结构神奇地协同工作:作为元组,其中第二个参数是张量标签;或者作为字典,其中包含“label[s]”键。有关更多详细信息,请查看 labels_getter 参数的文档。

如果您的样本具有不同的结构,您仍然可以通过将可调用对象传递给 labels_getter 参数来使用 CutMix 和 MixUp。例如

batch = {
    "imgs": torch.rand(4, 3, 224, 224),
    "target": {
        "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
        "some_other_key": "this is going to be passed-through"
    }
}


def labels_getter(batch):
    return batch["target"]["classes"]


out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
out['imgs'].shape = torch.Size([4, 3, 224, 224]), out['target']['classes'].shape = torch.Size([4, 100])

脚本总运行时间: (0 分 0.186 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源