评价此页

torch.utils.module_tracker#

创建于: May 04, 2024 | 最后更新于: Jun 11, 2025

此工具可用于跟踪 torch.nn.Module 层级结构中的当前位置。它可用于其他跟踪工具中,以便轻松地将测量到的量与用户友好的名称关联起来。目前,FlopCounterMode 特别使用了它。

class torch.utils.module_tracker.ModuleTracker[源代码]#

ModuleTracker 是一个上下文管理器,可在执行期间跟踪 nn.Module 层级结构,以便其他系统可以查询当前正在执行(或正在执行其反向传播)的 Module。

您可以通过此上下文管理器访问 parents 属性,以获取当前通过其 fqn(完全限定名,也用作 state_dict 中的键)执行的所有 Module 的集合。您可以访问 is_bw 属性来了解您当前是否正在反向传播。

请注意,parents 永远不会为空,并且始终包含“Global”键。在正向传播完成后,is_bw 标志将保持 True,直到执行另一个 Module。如果您需要更准确的信息,请提交一个 issue 请求此功能。添加从 fqn 到 module 实例的映射是可能的,但尚未实现,如果您需要此功能,请提交一个 issue 请求。

使用示例

mod = torch.nn.Linear(2, 2)

with ModuleTracker() as tracker:
    # Access anything during the forward pass
    def my_linear(m1, m2, bias):
        print(f"Current modules: {tracker.parents}")
        return torch.mm(m1, m2.t()) + bias

    torch.nn.functional.linear = my_linear

    mod(torch.rand(2, 2))