评价此页

torch.nn.utils.convert_conv3d_weight_memory_format#

torch.nn.utils.convert_conv3d_weight_memory_format(module, memory_format)[源代码]#

nn.Conv3d.weightmemory_format 转换为指定的 memory_format。该转换会递归地应用于嵌套的 nn.Module,包括 module 本身。请注意,此函数仅更改 memory_format,而不改变每个维度的语义。此函数用于促进计算以采用 NHWC 内核,这可以为计算能力 >= 7.0 的 CUDA 设备上的 fp16 数据提供相当大的加速。

注意

调用 model.to(memory_format=torch.channels_last_3d) 比实用函数 convert_conv3d_weight_memory_format 更具侵入性。任何具有 4d 权重的层都会受到 model.to 的影响,这些层不一定能从转换为指定的 memory_format 中受益。我们有信心的一个领域是 cuDNN 中卷积的 NDHWC (channels_last_3d) 转换,因为在 NDHWC 中运行卷积是有益的,即使在需要对输入张量进行置换的情况下也是如此。

因此,我们的策略是仅将卷积的权重转换为 channels_last_3d。这确保了:1. 会使用快速卷积内核,其好处可能超过置换的开销(如果输入格式不匹配)。2. 不会对不受益于 memory_format 转换的层应用不必要的置换。

最佳情况是,卷积层之间的层是 channels last 兼容的。输入张量在遇到第一个卷积层时会被排列成 channels last,并保持在该内存格式。因此,后续的卷积层将不需要排列其输入张量。

如果 channels last 不兼容的层位于卷积层之间,我们需要将输入张量排列回连续格式(contiguous format)以供该层使用。输入张量将以连续格式通过剩余的层,并在遇到另一个卷积层时被排列成 channels last。将该排列传播到更早的层是没有意义的,因为大多数层对 memory_format 都相当不敏感。

当 PyTorch 支持排列的融合时,这个说法可能会改变,因为可能存在比卷积层前立即融合排列更好的位置。

参数
  • module (nn.Module) – nn.Conv3d & nn.ConvTranspose3d 或容器 nn.Module

  • memory_format (memory_format) – 用户指定的 memory_format,例如 torch.channels_lasttorch.contiguous_format

返回

具有已更新 nn.Conv3d 的原始模块

返回类型

_M

示例

>>> input = torch.randint(
...     1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda"
... )
>>> model = nn.Sequential(
>>>     nn.Conv3d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(
...     model, torch.channels_last_3d
... )
>>> out = model(input)