Unfold#
- class torch.nn.modules.fold.Unfold(kernel_size, dilation=1, padding=0, stride=1)[源码]#
从批处理的输入 Tensor 中提取滑动局部块。
考虑一个形状为 的批量输入张量,其中 是批次维度, 是通道维度, 代表任意空间维度。此操作会将输入张量空间维度中每个
kernel_size
大小的滑动块展平成一个列(即最后一个维度),得到一个形状为 的 3-D 输出张量,其中 是每个块内的值总数(一个块有 个空间位置,每个位置包含一个 通道的向量),并且 是这些块的总数。其中 由输入张量的空间维度组成(上面是 ),并且 遍历所有空间维度。
因此,索引输出张量的最后一个维度(列维度)将获得某个块内的所有值。
参数
padding
、stride
和dilation
指定了如何提取滑动块。stride
控制滑动块的步幅。padding
控制在重塑之前,每个维度上的padding
个点两侧的隐式零填充量。dilation
控制核点之间的间距;也称为空洞卷积算法。这个概念比较难描述,但 这个链接 有一个dilation
作用的可视化。
- 参数
如果
kernel_size
、dilation
、padding
或stride
是一个整数或长度为 1 的元组,其值将复制到所有空间维度。对于具有两个输入空间维度的场景,此操作有时称为
im2col
。
注意
Fold
通过对所有包含块的值求和来计算结果大张量中的每个组合值。Unfold
通过从大张量中复制来提取局部块中的值。因此,如果块重叠,它们不是彼此的逆操作。一般来说,折叠和展开操作相关如下。考虑使用相同参数创建的
Fold
和Unfold
实例。>>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) >>> fold = nn.Fold(output_size=..., **fold_params) >>> unfold = nn.Unfold(**fold_params)
那么对于任何(受支持的)
input
张量,以下等式成立:fold(unfold(input)) == divisor * input
其中
divisor
是一个仅取决于input
的形状和 dtype 的张量。>>> input_ones = torch.ones(input.shape, dtype=input.dtype) >>> divisor = fold(unfold(input_ones))
当
divisor
张量不包含零元素时,则fold
和unfold
操作是彼此的逆运算( up to constant divisor)。警告
目前,仅支持 4 维输入张量(批处理的类图像张量)。
- 形状
输入:
输出: ,如上所述。
示例
>>> unfold = nn.Unfold(kernel_size=(2, 3)) >>> input = torch.randn(2, 5, 3, 4) >>> output = unfold(input) >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels) >>> # 4 blocks (2x3 kernels) in total in the 3x4 input >>> output.size() torch.Size([2, 30, 4]) >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape) >>> inp = torch.randn(1, 3, 10, 12) >>> w = torch.randn(2, 3, 4, 5) >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5)) >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2) >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1)) >>> # or equivalently (and avoiding a copy), >>> # out = out_unf.view(1, 2, 7, 8) >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max() tensor(1.9073e-06)