torch.utils.module_tracker#
创建于:2024 年 5 月 4 日 | 最后更新于:2025 年 6 月 11 日
此实用工具可用于跟踪 torch.nn.Module
层级结构中的当前位置。它可用于其他跟踪工具中,以便轻松地将测量到的数量与用户友好的名称关联起来。例如,FlopCounterMode 目前就使用了此工具。
- class torch.utils.module_tracker.ModuleTracker[source]#
ModuleTracker
是一个上下文管理器,可在执行期间跟踪 nn.Module 层级结构,以便其他系统可以查询当前正在执行哪个 Module(或其反向传播正在执行)。您可以通过此上下文管理器访问
parents
属性,以获取当前通过其 fqn(完全限定名,也用作 state_dict 中的键)执行的所有 Module 的集合。您可以通过访问is_bw
属性来了解您当前是否正在执行反向传播。请注意,
parents
永不为空,并且始终包含“Global”键。is_bw
标志将在前向传播完成后保持True
,直到执行另一个 Module。如果您需要更精确,请提交一个问题请求此功能。添加一个从 fqn 到 module 实例的映射是可能的,但尚未实现,如果您需要,请提交一个问题请求此功能。使用示例
mod = torch.nn.Linear(2, 2) with ModuleTracker() as tracker: # Access anything during the forward pass def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias torch.nn.functional.linear = my_linear mod(torch.rand(2, 2))