torch.nn.utils.convert_conv2d_weight_memory_format#
- torch.nn.utils.convert_conv2d_weight_memory_format(module, memory_format)[source]#
Convert
memory_format
ofnn.Conv2d.weight
tomemory_format
。该转换递归地应用于嵌套的
nn.Module
,包括module
。请注意,它仅更改 memory_format,而不更改每个维度的语义。此函数用于促进计算采用 NHWC 内核,这为 CUDA 设备上计算能力 >= 7.0 的 fp16 数据提供了可观的加速。注意
调用
model.to(memory_format=torch.channels_last)
比实用函数convert_conv2d_weight_memory_format
更具侵略性。任何具有 4d 权重的层都将受到model.to
的影响,这不一定有利于转换为指定的memory_format
。我们可以确信的一点是,cuDNN 中卷积的 NHWC (channels_last) 转换,因为即使在必须对输入张量进行置换以匹配 channels_last 的情况下,在 NHWC 中运行卷积也是有益的。因此,我们的策略是仅将卷积的权重转换为 channels_last。这确保了:1. 将使用快速卷积内核,其优势可能超过置换的开销(如果输入格式不兼容)。2. 不会对不需要 memory_format 转换的层应用不必要的置换。
最佳情况是,卷积层之间的层都与 channels last 兼容。当输入张量遇到第一个卷积层时,它将被置换为 channels last 格式,并保持该内存格式。因此,后续的卷积将不需要置换其输入张量。
在卷积层之间存在 channels last 不兼容的层的情况下,我们需要将输入张量置换回该层的连续格式。输入张量将以连续格式通过其余层,并在遇到另一个卷积层时被置换为 channels last。将该置换传播到更早的层是没有意义的,因为大多数层对
memory_format
都相当不敏感。当 PyTorch 支持置换的融合时,此声明可能会发生变化,因为可能存在比卷积层之前更好的融合置换的位置。
- 参数
module (nn.Module) –
nn.Conv2d
&nn.ConvTranspose2d
或容器nn.Module
memory_format (memory_format) – 用户指定的
memory_format
,例如torch.channels_last
或torch.contiguous_format
- 返回
具有更新的
nn.Conv2d
的原始模块- 返回类型
_M
示例
>>> input = torch.randint( ... 1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda" ... ) >>> model = nn.Sequential( >>> nn.Conv2d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) >>> model = nn.utils.convert_conv2d_weight_memory_format( ... model, torch.channels_last ... ) >>> out = model(input)