评价此页

torch.nn.functional.affine_grid#

torch.nn.functional.affine_grid(theta, size, align_corners=None)[source]#

生成二维或三维的采样网格(flow field),输入为仿射矩阵批次 theta

注意

此函数通常与 grid_sample() 结合使用,以构建 空间变换网络

参数
  • theta (Tensor) – 输入的仿射矩阵批次,形状为 (N×2×3N \times 2 \times 3) (用于二维)或 (N×3×4N \times 3 \times 4) (用于三维)。

  • size (torch.Size) – 目标输出图像尺寸。对于二维,形状为 (N×C×H×WN \times C \times H \times W;对于三维,形状为 N×C×D×H×WN \times C \times D \times H \times W)。例如:torch.Size((32, 3, 24, 24))

  • align_corners (bool, optional) – 如果为 True,则将 -11 视为角像素的中心,而不是图像的角。有关更多描述,请参阅 grid_sample()。应将 affine_grid() 生成的网格与此选项相同的设置一起传递给 grid_sample()。默认为 False

返回

输出张量,形状为 (N×H×W×2N \times H \times W \times 2)

返回类型

输出(Tensor

警告

align_corners = True时,网格位置取决于像素大小相对于输入图像大小,因此对于在不同分辨率(即经过上采样或下采样后)提供的相同输入,grid_sample()采样的位置将不同。在1.2.0版本之前的默认行为是align_corners = True。此后,为了与interpolate()的默认行为保持一致,默认行为已更改为align_corners = False

警告

align_corners = True时,1D数据上的2D仿射变换和2D数据上的3D仿射变换(即,当一个空间维度的大小为单位时)是定义不明确的,并且不是预期的使用场景。当align_corners = False时,这不是问题。在1.2.0版本之前,单位维度上的所有网格点被任意地视为-1。从1.3.0版本开始,在align_corners = True下,单位维度上的所有网格点被视为0(输入图像的中心)。