如何编写自己的 TVTensor 类¶
本指南面向高级用户和下游库维护者。我们将解释如何编写自己的 TVTensor 类,以及如何使其与内置的 Torchvision v2 变换兼容。在继续之前,请确保您已阅读 TVTensors FAQ。
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
我们将创建一个非常简单的类,它仅继承自基类 TVTensor
。这足以涵盖您需要了解的内容,以实现更复杂的用例。如果您需要创建一个承载元数据的类,请查看 BoundingBoxes
类的 实现方式。
class MyTVTensor(tv_tensors.TVTensor):
pass
my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])
现在我们已经定义了自定义 TVTensor 类,我们希望它与内置的 torchvision 变换和函数式 API 兼容。为此,我们需要实现一个执行变换核心操作的内核,然后通过 register_kernel()
将其“挂钩”到我们想要支持的函数式 API 上。
我们在下面说明了这个过程:我们为 MyTVTensor 类的“水平翻转”操作创建一个内核,并将其注册到函数式 API。
from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
要了解为什么使用 wrap()
,请参阅 我曾拥有一个 TVTensor,但现在我拥有一个 Tensor。救命!。暂时忽略 *args, **kwargs
,我们将在下面的 参数转发,以及确保内核的未来兼容性 中进行解释。
注意
在我们上面的 register_kernel
调用中,我们使用字符串 functional="hflip"
来引用我们想要挂钩的函数式 API。我们也可以直接使用函数式 API 本身,即 @register_kernel(functional=F.hflip, ...)
。
现在我们已经注册了我们的内核,我们可以对 MyTVTensor
实例调用函数式 API。
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!
我们也可以使用 RandomHorizontalFlip
变换,因为它在内部依赖于 hflip()
。
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!
注意
我们不能为变换类注册内核,只能为函数式 API 注册内核。我们不能注册变换类是因为一个变换可能在内部依赖于多个函数式 API,所以一般情况下我们不能为给定的类注册一个单独的内核。
参数转发,以及确保内核的未来兼容性¶
您正在挂钩的函数式 API 是公共的,因此是向后兼容的:我们保证这些函数式 API 的参数不会在没有适当弃用周期的情况下被移除或重命名。然而,我们不保证向前兼容性,并且我们可能会在将来添加新参数。
设想一下,在未来的某个版本中,Torchvision 向其 hflip()
函数式 API 添加了一个新的 inplace
参数。如果您已经定义并注册了自己的内核,如下所示:
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
那么调用 F.hflip(my_dp)
将会失败,因为 hflip
将会尝试将新的 inplace
参数传递给您的内核,但您的内核不接受它。
因此,我们建议始终在内核的签名中定义 *args, **kwargs
,如上所示。这样,您的内核将能够接受我们将来可能添加的任何新参数。(技术上来说,只添加 **kwargs 应该就足够了)。
脚本总运行时间: (0 分钟 0.004 秒)