评价此页

ChannelShuffle#

class torch.nn.ChannelShuffle(groups)[source]#

分割并重新排列张量中的通道。

该操作将形状为 (N,C,)(N, C, *) 的张量中的通道划分为 g 个组,形式为 (N,Cg,g,)(N, \frac{C}{g}, g, *) 并对它们进行重排(shuffle),最终输出保持原始张量形状。

参数:

groups (int) – 将通道划分成的组数。

示例

>>> channel_shuffle = nn.ChannelShuffle(2)
>>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
>>> input
tensor([[[[ 1.,  2.],
          [ 3.,  4.]],
         [[ 5.,  6.],
          [ 7.,  8.]],
         [[ 9., 10.],
          [11., 12.]],
         [[13., 14.],
          [15., 16.]]]])
>>> output = channel_shuffle(input)
>>> output
tensor([[[[ 1.,  2.],
          [ 3.,  4.]],
         [[ 9., 10.],
          [11., 12.]],
         [[ 5.,  6.],
          [ 7.,  8.]],
         [[13., 14.],
          [15., 16.]]]])
extra_repr()[source]#

返回模块的额外表示。

返回类型:

str

forward(input)[source]#

执行前向传播。

返回类型:

张量