快捷方式

如何编写你自己的 v2 变换

注意

Colab 上试用,或前往底部下载完整的示例代码。

本指南介绍了如何编写与 torchvision 变换 V2 API 兼容的变换。

from typing import Any, Dict, List

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

只需创建一个 nn.Module 并重写 forward 方法

在大多数情况下,只要你已经知道变换将期望的输入结构,这就足够了。例如,如果你只做图像分类,你的变换通常会接受单个图像作为输入,或者接受 (img, label) 作为输入。所以你可以硬编码你的 forward 方法来只接受这些,例如:

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, label):
        # Do some transformations
        return new_img, new_label

注意

这意味着,如果你有一个自定义变换,它已经与 V1 变换(在 torchvision.transforms 中)兼容,那么它无需任何更改即可与 V2 变换一起使用!

我们将通过一个典型的检测案例来更完整地说明这一点,在该案例中,我们的样本只是图像、边界框和标签。

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this
        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing though the input
        return img, bboxes, label


transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)
I'm transforming an image of shape torch.Size([3, 256, 256]) with bboxes = BoundingBoxes([[ 0, 10, 10, 20],
               [50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
label = 3
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
Output image shape: torch.Size([3, 224, 224])
out_bboxes = BoundingBoxes([[221,   0, 224,   0],
               [114,  27, 150,  61]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224), clamping_mode=soft)
out_label = 3

注意

在使用 TVTensor 类进行编码时,请确保熟悉本节:我有一个 TVTensor,但现在我得到了一个 Tensor。救命!

支持任意输入结构

在上一节中,我们假设你已经知道输入的结构,并且你愿意在代码中硬编码这个期望的结构。如果你希望你的自定义变换尽可能灵活,这可能会有些限制。

内置的 Torchvision V2 变换的一个关键特性是它们可以接受任意输入结构并返回相同的结构作为输出(经过变换的条目)。例如,变换可以接受单个图像,或 (img, label) 的元组,或者嵌套的字典作为输入。下面是一个关于内置变换 RandomHorizontalFlip 的示例。

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something that will be ignored": (1, "hello"),
    "another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
               [50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
The transformed bboxes are:
BoundingBoxes([[246,  10, 256,  20],
               [186,  50, 206,  70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)

基础:重写 transform() 方法

为了在你的自定义变换中支持任意输入,你需要继承自 Transform 并重写 .transform() 方法(而不是 forward() 方法!)。下面是一个基本示例:

class MyCustomTransform(v2.Transform):
    def transform(self, inpt: Any, params: Dict[str, Any]):
        if type(inpt) == torch.Tensor:
            print(f"I'm transforming an image of shape {inpt.shape}")
            return inpt + 1  # dummy transformation
        elif isinstance(inpt, tv_tensors.BoundingBoxes):
            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation


my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
               [50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
The transformed bboxes are:
BoundingBoxes([[100, 110, 110, 120],
               [150, 150, 170, 170]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)

需要注意的一个重要事项是,当我们对 structured_input 调用 my_custom_transform 时,输入会被展平,然后每个单独的部分会被传递给 transform()。也就是说,transform() 接收输入图像,然后是边界框,等等。在 transform() 内部,你可以根据输入的类型决定如何变换每个输入。

如果你好奇为什么另一个张量(torch.arange())没有被传递给 transform(),请参阅此注释了解更多详细信息。

高级:make_params() 方法

在调用 transform() 对每个输入进行处理之前,make_params() 方法会被内部调用。这通常对于生成随机参数值很有用。在下面的示例中,我们使用它以 0.5 的概率随机应用变换。

class MyRandomTransform(MyCustomTransform):
    def __init__(self, p=0.5):
        self.p = p
        super().__init__()

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        apply_transform = (torch.rand(size=(1,)) < self.p).item()
        params = dict(apply_transform=apply_transform)
        return params

    def transform(self, inpt: Any, params: Dict[str, Any]):
        if not params["apply_transform"]:
            print("Not transforming anything!")
            return inpt
        else:
            return super().transform(inpt, params)


my_random_transform = MyRandomTransform()

torch.manual_seed(0)
_ = my_random_transform(structured_input)  # transforms
_ = my_random_transform(structured_input)  # doesn't transform
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
Not transforming anything!
Not transforming anything!

注意

对于这种随机参数生成,重要的是它发生在 make_params() 内部,而不是 transform() 内部,这样对于给定的变换调用,相同的 RNG 会以相同的方式应用于所有输入。如果我们要在 transform() 中执行 RNG,我们可能会面临例如在变换图像时变换边界框的风险。

make_params() 方法将所有输入的列表作为参数(此列表中的每个元素稍后都会传递给 transform())。你可以使用 flat_inputs 来例如确定输入的维度,使用 query_chw()query_size()

make_params() 应该返回一个字典(或者实际上,任何你想要的东西),它将被传递给 transform()

脚本总运行时间: (0 分钟 0.009 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源