快捷方式

get_graph_node_names

torchvision.models.feature_extraction.get_graph_node_names(model: Module, tracer_kwargs: Optional[dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[dict[str, Any]] = None) tuple[list[str], list[str]][源代码]

用于返回按执行顺序排列的节点的开发实用程序。有关节点名称的说明,请参阅 create_feature_extractor()。这对于查看可用于特征提取的节点名称非常有用。节点名称无法轻松从模型代码中直接读取有两个原因:

  1. 并非所有子模块都会被追踪。来自 torch.nn 的模块都属于此类。

  2. 表示相同操作或叶子模块重复应用的节点会获得一个 _{counter} 后缀。

模型会追踪两次:一次在训练模式下,一次在评估模式下。将返回两个节点的名称集。

有关此处使用的节点命名约定的更多详细信息,请参阅 相关小节,位于 文档 中。

参数:
  • model (nn.Module) – 我们想要打印节点名称的模型

  • tracer_kwargs (dict, optional) – 用于 NodePathTracer 的关键字参数字典(最终会传递给 torch.fx.Tracer)。默认情况下,它将包装并使所有 torchvision 操作成为叶子节点:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果用户提供 tracer_kwargs,上述默认参数将附加到用户提供的字典中。

  • suppress_diff_warning (bool, optional) – 当训练和评估版本图存在差异时,是否抑制警告。默认为 False。

  • concrete_args (Optional[Dict[str, any]]) – 不应被视为代理的具体参数。根据 Pytorch 文档,此参数的 API 可能不被保证。

返回:

一个在训练模式下追踪模型得到的节点名称列表,以及一个在评估模式下追踪模型得到的节点名称列表。

返回类型:

tuple(list, list)

示例

>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源