评价此页

PyTorch 中的形状推理#

创建日期:2023年3月27日 | 最后更新:2023年3月27日 | 最后验证:未验证

在使用 PyTorch 编写模型时,通常情况下,给定层的参数取决于前一层输出的形状。例如,nn.Linear 层的 in_features 必须与输入的 size(-1) 相匹配。对于某些层,形状计算涉及复杂的等式,例如卷积操作。

一种解决方法是用随机输入运行前向传播,但这在内存和计算方面是一种浪费。

相反,我们可以利用 meta 设备来确定层的输出形状,而无需实例化任何数据。

import torch
import timeit

t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
tensor(..., device='meta', size=(2, 5, 9, 9), grad_fn=<ConvolutionBackward0>)
Time taken: 0.0003438769999775104

请注意,由于数据未被实例化,传递任意大的输入不会显著改变形状计算所需的时间。

t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")
start = timeit.default_timer()
out = conv(t_large)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
tensor(..., device='meta', size=(1024, 5, 65535, 65535),
       grad_fn=<ConvolutionBackward0>)
Time taken: 0.00012858300010520907

考虑一个如下所示的任意网络

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

我们可以通过为每个层注册一个前向钩子(forward hook)来查看整个网络中的中间形状,该钩子会打印输出的形状。

def fw_hook(module, input, output):
    print(f"Shape of output to {module} is {output.shape}.")


# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))

for name, layer in net.named_modules():
    layer.register_forward_hook(fw_hook)

out = net(inp)
Shape of output to Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 6, 28, 28]).
Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 6, 14, 14]).
Shape of output to Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 16, 10, 10]).
Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 16, 5, 5]).
Shape of output to Linear(in_features=400, out_features=120, bias=True) is torch.Size([1024, 120]).
Shape of output to Linear(in_features=120, out_features=84, bias=True) is torch.Size([1024, 84]).
Shape of output to Linear(in_features=84, out_features=10, bias=True) is torch.Size([1024, 10]).
Shape of output to Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
) is torch.Size([1024, 10]).

脚本总运行时间: (0 分钟 0.016 秒)