如何编写自己的 v2 变换¶
本指南将介绍如何编写与 torchvision transforms V2 API 兼容的变换。
只需创建一个 nn.Module
并重写 forward
方法¶
在大多数情况下,只要您了解变换将期望的输入结构,这就可以了。例如,如果您仅进行图像分类,您的变换通常会接受单个图像作为输入,或者接受 (img, label)
作为输入。因此,您可以直接将 forward
方法硬编码为仅接受该输入,例如:
class MyCustomTransform(torch.nn.Module):
def forward(self, img, label):
# Do some transformations
return new_img, new_label
注意
这意味着,如果您有一个自定义变换已兼容 V1 变换(位于 torchvision.transforms
中),它无需任何更改即可与 V2 变换一起使用!
下面我们将更全面地说明这一点,以一个典型的检测案例为例,其中我们的样本仅为图像、边界框和标签。
class MyCustomTransform(torch.nn.Module):
def forward(self, img, bboxes, label): # we assume inputs are always structured like this
print(
f"I'm transforming an image of shape {img.shape} "
f"with bboxes = {bboxes}\n{label = }"
)
# Do some transformations. Here, we're just passing though the input
return img, bboxes, label
transforms = v2.Compose([
MyCustomTransform(),
v2.RandomResizedCrop((224, 224), antialias=True),
v2.RandomHorizontalFlip(p=1),
v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])
H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
)
label = 3
out_img, out_bboxes, out_label = transforms(img, bboxes, label)
I'm transforming an image of shape torch.Size([3, 256, 256]) with bboxes = BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
label = 3
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
Output image shape: torch.Size([3, 224, 224])
out_bboxes = BoundingBoxes([[218, 7, 224, 16],
[148, 43, 171, 61]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224), clamping_mode=soft)
out_label = 3
注意
在您的代码中使用 TVTensor 类时,请确保您熟悉本节:我有一个 TVTensor,现在我有一个 Tensor。求助!
支持任意输入结构¶
在上一节中,我们假设您已经了解输入的结构,并且您可以在代码中硬编码此期望的结构。如果您希望自定义变换尽可能灵活,这可能有点局限。
内置 Torchvision V2 变换的一个关键特性是它们可以接受任意输入结构并返回相同的结构作为输出(经过变换的条目)。例如,变换可以接受单个图像,或者一个 (img, label)
元组,或者任意嵌套的字典作为输入。以下是内置变换 RandomHorizontalFlip
的示例:
structured_input = {
"img": img,
"annotations": (bboxes, label),
"something that will be ignored": (1, "hello"),
"another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
The transformed bboxes are:
BoundingBoxes([[246, 10, 256, 20],
[186, 50, 206, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
基础:重写 transform() 方法¶
为了支持自定义变换中的任意输入,您需要继承自 Transform
并重写 .transform() 方法(而不是 forward() 方法!)。下面是一个基本示例:
class MyCustomTransform(v2.Transform):
def transform(self, inpt: Any, params: Dict[str, Any]):
if type(inpt) == torch.Tensor:
print(f"I'm transforming an image of shape {inpt.shape}")
return inpt + 1 # dummy transformation
elif isinstance(inpt, tv_tensors.BoundingBoxes):
print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
return tv_tensors.wrap(inpt + 100, like=inpt) # dummy transformation
my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
The transformed bboxes are:
BoundingBoxes([[100, 110, 110, 120],
[150, 150, 170, 170]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
需要注意的一个重要事项是,当我们对 structured_input
调用 my_custom_transform
时,输入会被展平,然后每个单独的部分会被传递到 transform()
。也就是说,transform()
依次接收输入图像、边界框等。在 transform()
中,您可以根据每个输入的类型来决定如何转换它们。
如果您好奇为什么另一个张量(torch.arange()
)没有被传递到 transform()
,请参阅此说明了解更多详情。
高级:make_params()
方法¶
在调用 transform()
对每个输入进行处理之前,内部会调用 make_params()
方法。这通常对于生成随机参数值很有用。在下面的示例中,我们使用它以 0.5 的概率随机应用变换:
class MyRandomTransform(MyCustomTransform):
def __init__(self, p=0.5):
self.p = p
super().__init__()
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
apply_transform = (torch.rand(size=(1,)) < self.p).item()
params = dict(apply_transform=apply_transform)
return params
def transform(self, inpt: Any, params: Dict[str, Any]):
if not params["apply_transform"]:
print("Not transforming anything!")
return inpt
else:
return super().transform(inpt, params)
my_random_transform = MyRandomTransform()
torch.manual_seed(0)
_ = my_random_transform(structured_input) # transforms
_ = my_random_transform(structured_input) # doesn't transform
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
Not transforming anything!
Not transforming anything!
注意
对于此类随机参数生成,在 make_params()
中进行而不是在 transform()
中进行非常重要,这样对于给定的变换调用,相同的 RNG 会以相同的方式应用于所有输入。如果我们要在 transform()
中执行 RNG,我们将面临风险,例如变换了图像,但没有变换边界框。
make_params()
方法接收所有输入列表作为参数(此列表中的每个元素稍后都会传递给 transform()
)。您可以使用 flat_inputs
来例如确定输入数据的维度,使用 query_chw()
或 query_size()
。
make_params()
应该返回一个字典(或者实际上,任何您想要的东西),然后该字典将被传递给 transform()
。
脚本总运行时间: (0 分钟 0.009 秒)