评价此页

理解基于 TorchDynamo 的 ONNX 导出器内存使用情况#

创建时间:2024年11月06日 | 最后更新时间:2025年06月18日

以前的基于 TorchScript 的 ONNX 导出器会执行模型一次以追踪其执行,这可能导致在 GPU 内存不足的情况下模型耗尽 GPU 内存。这个问题已通过新的基于 TorchDynamo 的 ONNX 导出器得到解决。

基于 TorchDynamo 的 ONNX 导出器利用 torch.export.export() 函数来利用 FakeTensorMode,从而避免在导出过程中执行实际的张量计算。与基于 TorchScript 的 ONNX 导出器相比,这种方法可以显著降低内存使用量。

下面是一个示例,展示了基于 TorchScript 和基于 TorchDynamo 的 ONNX 导出器之间的内存使用差异。在此示例中,我们使用了 MONAI 的 HighResNet 模型。在继续之前,请从 PyPI 安装它。

pip install monai

PyTorch 提供了一个捕获和可视化内存使用轨迹的工具。我们将使用此工具在导出过程中记录两个导出器的内存使用情况并进行比较。您可以在 理解 CUDA 内存使用情况 上找到有关此工具的更多详细信息。

基于 TorchScript 的导出器#

可以运行以下代码来生成一个快照文件,该文件记录了导出过程中分配的 CUDA 内存状态。

import torch

from monai.networks.nets import (
    HighResNet,
)

torch.cuda.memory._record_memory_history()

model = HighResNet(
    spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()

model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

with torch.no_grad():
    onnx_program = torch.onnx.export(
        model,
        data,
        "torchscript_exporter_highresnet.onnx",
        dynamo=False,
    )

snapshot_name = "torchscript_exporter_example.pickle"
print(f"generate {snapshot_name}")

torch.cuda.memory._dump_snapshot(snapshot_name)
print("Export is done.")

打开 pytorch.org/memory_viz 并将生成的 pickle 快照文件拖放到可视化工具中。内存使用情况如下所示:

_images/torch_script_exporter_memory_usage.png

从图中可以看出,内存使用峰值高于 2.8 GB。

基于 TorchDynamo 的导出器#

可以运行以下代码来生成一个快照文件,该文件记录了导出过程中分配的 CUDA 内存状态。

import torch

from monai.networks.nets import (
    HighResNet,
)

torch.cuda.memory._record_memory_history()

model = HighResNet(
    spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()

model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

with torch.no_grad():
    onnx_program = torch.onnx.export(
                        model,
                        data,
                        "test_faketensor.onnx",
                        dynamo=True,
                    )

snapshot_name = f"torchdynamo_exporter_example.pickle"
print(f"generate {snapshot_name}")

torch.cuda.memory._dump_snapshot(snapshot_name)
print(f"Export is done.")

打开 pytorch.org/memory_viz 并将生成的 pickle 快照文件拖放到可视化工具中。内存使用情况如下所示:

_images/torch_dynamo_exporter_memory_usage.png

从图中可以看出,内存使用峰值仅为 45MB 左右。与基于 TorchScript 的导出器的内存使用峰值相比,内存使用量减少了 98%。