TVTensors FAQ¶
TVTensors 是与 torchvision.transforms.v2
一起引入的 Tensor 子类。本示例展示了这些 TVTensors 是什么以及它们的行为方式。
警告
目标读者 除非您正在编写自己的 transforms 或自己的 TVTensors,否则您可能不需要阅读本指南。这是一个相当底层的 Topics,大多数用户无需担心:您无需了解 TVTensors 的内部原理即可有效地依赖 torchvision.transforms.v2
。但对于尝试实现自己的数据集、transforms 或直接使用 TVTensors 的高级用户可能有用。
import PIL.Image
import torch
from torchvision import tv_tensors
什么是 TVTensors?¶
TVTensors 是零拷贝的 Tensor 子类
tensor = torch.rand(3, 256, 256)
image = tv_tensors.Image(tensor)
assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()
在底层,torchvision.transforms.v2
需要它们来正确地分派到适用于输入数据的相应函数。
torchvision.tv_tensors
支持五种类型的 TVTensors
我可以用 TVTensor 做什么?¶
TVTensors 的外观和感觉就像常规 Tensor 一样——它们就是 Tensor。支持常规 torch.Tensor
的所有操作,例如 .sum()
或任何 torch.*
操作符,同样适用于 TVTensors。有关一些注意事项,请参阅 我有一个 TVTensor,但现在我得到了一个 Tensor。求助!。
如何构造 TVTensor?¶
使用构造函数¶
每个 TVTensor 类都可以接受任何可以转换为 Tensor
的 tensor 类型数据
image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image)
Image([[[[0, 1],
[1, 0]]]], )
与其他的 PyTorch 创建操作类似,构造函数还接受 dtype
、device
和 requires_grad
参数。
float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)
Image([[[0., 1.],
[1., 0.]]], grad_fn=<AliasBackward0>, )
此外,Image
和 Mask
也可以直接接受 PIL.Image.Image
image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype)
torch.Size([3, 512, 512]) torch.uint8
某些 TVTensors 需要传递额外的元数据才能构造。例如,BoundingBoxes
需要坐标格式以及相应图像的大小(canvas_size
)以及实际值。这些元数据对于正确转换边界框是必需的。类似地,KeyPoints
也需要添加 canvas_size
元数据。
bboxes = tv_tensors.BoundingBoxes(
[[17, 16, 344, 495], [0, 10, 0, 10]],
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=image.shape[-2:]
)
print(bboxes)
keypoints = tv_tensors.KeyPoints(
[[17, 16], [344, 495], [0, 10], [0, 10]],
canvas_size=image.shape[-2:]
)
print(keypoints)
BoundingBoxes([[ 17, 16, 344, 495],
[ 0, 10, 0, 10]], format=BoundingBoxFormat.XYXY, canvas_size=torch.Size([512, 512]), clamping_mode=soft)
KeyPoints([[ 17, 16],
[344, 495],
[ 0, 10],
[ 0, 10]], canvas_size=torch.Size([512, 512]))
使用 tv_tensors.wrap()
¶
您还可以使用 wrap()
函数将 tensor 对象包装到 TVTensor 中。当您已经拥有所需类型的对象时,这很有用,这通常发生在编写 transforms 时:您只想像输入一样包装输出。
new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size
new_bboxes
的元数据与 bboxes
相同,但您可以将其作为参数传递以覆盖它。
我有一个 TVTensor,但现在我得到了一个 Tensor。求助!¶
默认情况下,对 TVTensor
对象的操作将返回一个纯 Tensor
assert isinstance(bboxes, tv_tensors.BoundingBoxes)
# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3
assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)
注意
此行为仅影响原生的 torch
操作。如果您使用的是内置的 torchvision
transforms 或 functionals,您将始终获得与输入相同的输出类型(纯 Tensor
或 TVTensor
)。
但我想要一个 TVTensor!¶
您可以通过调用 TVTensor 构造函数或使用 wrap()
函数(请参阅上面 如何构造 TVTensor? 中的更多详细信息)来将纯 Tensor 重新包装为 TVTensor。
new_bboxes = bboxes + 3
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
或者,您可以使用 set_return_type()
作为整个程序的全局配置设置,或作为上下文管理器(阅读其文档以了解更多关于注意事项)。
with tv_tensors.set_return_type("TVTensor"):
new_bboxes = bboxes + 3
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
为什么会发生这种情况?¶
出于性能原因。TVTensor
类是 Tensor 子类,因此任何涉及 TVTensor
对象的都会通过 __torch_function__ 协议。这会产生一些开销,我们希望尽可能避免。这对于内置的 torchvision
transforms 来说无关紧要,因为我们可以避免那里的开销,但它可能会成为您模型 forward
中的问题。
另一种选择也并没有好多少。 对于每一个保留 TVTensor
类型有意义的操作,都有同样多的操作返回纯 Tensor 是可取的:例如,img.sum()
仍然是 Image
吗?如果我们一直保留 TVTensor
类型,即使是模型的 logits 或损失函数的输出最终也会是 Image
类型,这肯定是不理想的。
注意
这是我们正在积极征求反馈的行为。如果您觉得这令人惊讶,或者对如何更好地支持您的用例有任何建议,请通过此问题与我们联系:https://github.com/pytorch/vision/issues/7319
例外¶
这条“解包装”规则有一些例外:clone()
、to()
、torch.Tensor.detach()
和 requires_grad_()
会保留 TVTensor 类型。
TVTensors 上的原地操作,例如 obj.add_()
,将保留 obj
的类型。但是,原地操作的返回值将是纯 Tensor。
image = tv_tensors.Image([[[0, 1], [1, 0]]])
new_image = image.add_(1).mul_(2)
# image got transformed in-place and is still a TVTensor Image, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, tv_tensors.Image)
print(image)
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
Image([[[2, 4],
[4, 2]]], )
脚本总运行时间:(0 分钟 0.008 秒)