评价此页

torch.hub#

创建于: 2025年6月13日 | 最后更新于: 2025年6月13日

Pytorch Hub 是一个预训练模型库,旨在促进研究的可复现性。

发布模型#

Pytorch Hub 支持通过添加一个简单的 hubconf.py 文件将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库;

hubconf.py 可以有多个入口点。每个入口点都定义为一个 Python 函数(例如:您想要发布的预训练模型)。

  def entrypoint_name(*args, **kwargs):
      # args & kwargs are optional, for models which take positional/keyword arguments.
      ...

如何实现入口点?#

以下是为 resnet18 模型指定入口点的代码片段,如果我们扩展 pytorch/vision/hubconf.py 中的实现。在大多数情况下,在 hubconf.py 中导入正确的函数就足够了。这里我们只是想使用扩展版本作为示例来展示它的工作原理。您可以在 pytorch/vision 仓库中查看完整脚本。

  dependencies = ['torch']
  from torchvision.models.resnet import resnet18 as _resnet18

  # resnet18 is the name of entrypoint
  def resnet18(pretrained=False, **kwargs):
      """ # This docstring shows up in hub.help()
      Resnet18 model
      pretrained (bool): kwargs, load pretrained weights into the model
      """
      # Call the model, load pretrained weights
      model = _resnet18(pretrained=pretrained, **kwargs)
      return model
  • dependencies 变量是一个包名**列表**,加载模型所需。请注意,这可能与训练模型所需的依赖项略有不同。

  • argskwargs 将传递给实际的可调用函数。

  • 函数的 Docstring 作为帮助消息。它解释了模型的功能以及允许的位置/关键字参数。强烈建议在此处添加一些示例。

  • 入口点函数可以返回模型(nn.module),也可以返回辅助工具以使用户工作流程更顺畅,例如分词器。

  • 以_开头的可调用函数被视为辅助函数,不会出现在 torch.hub.list() 中。

  • 预训练权重可以存储在 GitHub 仓库本地,也可以通过 torch.hub.load_state_dict_from_url() 加载。如果小于 2GB,建议将其附加到项目发布并使用发布中的 URL。在上面的示例中,torchvision.models.resnet.resnet18 处理 pretrained,或者您可以将以下逻辑放入入口点定义中。

  if pretrained:
      # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
      dirname = os.path.dirname(__file__)
      checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
      state_dict = torch.load(checkpoint)
      model.load_state_dict(state_dict)

      # For checkpoint saved elsewhere
      checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
      model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要提示#

  • 发布的模型至少应位于分支/标签中。不能是随机提交。

从 Hub 加载模型#

PyTorch Hub 提供了便捷的 API,可以通过 torch.hub.list() 探索 Hub 中所有可用的模型,通过 torch.hub.help() 查看文档字符串和示例,并使用 torch.hub.load() 加载预训练模型。

torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[source]#

列出由 github 指定的仓库中所有可调用的入口点。

参数
  • github (str) – 格式为 “repo_owner/repo_name[:ref]” 的字符串,其中 ref (标签或分支) 是可选的。如果未指定 ref,则默认分支假定为 main(如果存在),否则为 master。示例:'pytorch/vision:0.10'

  • force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。默认为 False

  • skip_validation (bool, 可选) – 如果为 False,torchhub 将检查 github 参数指定的引用是否正确属于仓库所有者。这将向 GitHub API 发送请求;您可以通过设置 GITHUB_TOKEN 环境变量来指定非默认的 GitHub 令牌。默认为 False

  • trust_repo (bool, strNone) –

    "check"TrueFalseNone。此参数在 v1.12 中引入,有助于确保用户只运行他们信任的仓库中的代码。

    • 如果为 False,将提示用户是否应信任此仓库。

    • 如果为 True,该仓库将被添加到受信任列表,无需明确确认即可加载。

    • 如果为 "check",该仓库将对照缓存中的受信任仓库列表进行检查。如果不在该列表中,行为将回退到 trust_repo=False 选项。

    • 如果为 None:这将引发警告,提示用户将 trust_repo 设置为 FalseTrue"check"。这仅用于向后兼容,并将在 v2.0 中移除。

    默认为 None,并最终在 v2.0 中更改为 "check"

  • verbose (bool, 可选) – 如果为 False,则静默关于命中本地缓存的消息。请注意,首次下载的消息无法静默。默认为 True

返回

可用的可调用入口点

返回类型

列表

示例

>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[source]#

显示入口点 model 的文档字符串。

参数
  • github (str) – 格式为 <repo_owner/repo_name[:ref]> 的字符串,其中 ref (标签或分支) 是可选的。如果未指定 ref,则默认分支假定为 main(如果存在),否则为 master。示例:'pytorch/vision:0.10'

  • model (str) – 仓库的 hubconf.py 中定义的入口点名称字符串

  • force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。默认为 False

  • skip_validation (bool, 可选) – 如果为 False,torchhub 将检查 github 参数指定的引用是否正确属于仓库所有者。这将向 GitHub API 发送请求;您可以通过设置 GITHUB_TOKEN 环境变量来指定非默认的 GitHub 令牌。默认为 False

  • trust_repo (bool, strNone) –

    "check"TrueFalseNone。此参数在 v1.12 中引入,有助于确保用户只运行他们信任的仓库中的代码。

    • 如果为 False,将提示用户是否应信任此仓库。

    • 如果为 True,该仓库将被添加到受信任列表,无需明确确认即可加载。

    • 如果为 "check",该仓库将对照缓存中的受信任仓库列表进行检查。如果不在该列表中,行为将回退到 trust_repo=False 选项。

    • 如果为 None:这将引发警告,提示用户将 trust_repo 设置为 FalseTrue"check"。这仅用于向后兼容,并将在 v2.0 中移除。

    默认为 None,并最终在 v2.0 中更改为 "check"

示例

>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[source]#

从 GitHub 仓库或本地目录加载模型。

注意:加载模型是典型的用例,但此功能也可用于加载其他对象,例如分词器、损失函数等。

如果 source 为 'github',则 repo_or_dir 预期格式为 repo_owner/repo_name[:ref],其中 ref (标签或分支) 是可选的。

如果 source 为 'local',则 repo_or_dir 预期为本地目录的路径。

参数
  • repo_or_dir (str) – 如果 source 为 'github',这应对应于 GitHub 仓库,格式为 repo_owner/repo_name[:ref],其中 ref (标签或分支) 是可选的,例如 'pytorch/vision:0.10'。如果未指定 ref,则默认分支假定为 main(如果存在),否则为 master。如果 source 为 'local',则应为本地目录的路径。

  • model (str) – 在仓库/目录的 hubconf.py 中定义的可调用(入口点)的名称。

  • *args (可选) – 可调用 model 对应的参数。

  • source (str, 可选) – 'github' 或 'local'。指定如何解释 repo_or_dir。默认为 'github'。

  • trust_repo (bool, strNone) –

    "check"TrueFalseNone。此参数在 v1.12 中引入,有助于确保用户只运行他们信任的仓库中的代码。

    • 如果为 False,将提示用户是否应信任此仓库。

    • 如果为 True,该仓库将被添加到受信任列表,无需明确确认即可加载。

    • 如果为 "check",该仓库将对照缓存中的受信任仓库列表进行检查。如果不在该列表中,行为将回退到 trust_repo=False 选项。

    • 如果为 None:这将引发警告,提示用户将 trust_repo 设置为 FalseTrue"check"。这仅用于向后兼容,并将在 v2.0 中移除。

    默认为 None,并最终在 v2.0 中更改为 "check"

  • force_reload (bool, 可选) – 是否强制无条件重新下载 GitHub 仓库。如果 source = 'local' 则无效。默认为 False

  • verbose (bool, 可选) – 如果为 False,则静默关于命中本地缓存的消息。请注意,首次下载的消息无法静默。如果 source = 'local' 则无效。默认为 True

  • skip_validation (bool, 可选) – 如果为 False,torchhub 将检查 github 参数指定的引用是否正确属于仓库所有者。这将向 GitHub API 发送请求;您可以通过设置 GITHUB_TOKEN 环境变量来指定非默认的 GitHub 令牌。默认为 False

  • **kwargs (可选) – 可调用 model 对应的关键字参数。

返回

使用给定 *args**kwargs 调用 model 可调用对象时的输出。

示例

>>> # from a github repo
>>> repo = "pytorch/vision"
>>> model = torch.hub.load(
...     repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
... )
>>> # from a local directory
>>> path = "/some/local/path/pytorch/vision"
>>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[source]#

将给定 URL 处的对象下载到本地路径。

参数
  • url (str) – 要下载的对象的 URL

  • dst (str) – 对象将保存的完整路径,例如 /tmp/temporary_file

  • hash_prefix (str, 可选) – 如果不为 None,下载文件的 SHA256 应以 hash_prefix 开头。默认值:None

  • progress (bool, 可选) – 是否向 stderr 显示进度条。默认值:True

示例

>>> torch.hub.download_url_to_file(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
...     "/tmp/temporary_file",
... )
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source]#

加载给定 URL 处的 Torch 序列化对象。

如果下载的文件是 zip 文件,它将自动解压缩。

如果对象已存在于 model_dir 中,则会反序列化并返回。 model_dir 的默认值为 <hub_dir>/checkpoints,其中 hub_dir 是由 get_dir() 返回的目录。

参数
  • url (str) – 要下载的对象的 URL

  • model_dir (str, 可选) – 保存对象的目录

  • map_location (可选) – 一个函数或字典,指定如何重新映射存储位置(参见 torch.load)

  • progress (bool, 可选) – 是否向 stderr 显示进度条。默认值:True

  • check_hash (bool, 可选) – 如果为 True,则 URL 的文件名部分应遵循命名约定 filename-<sha256>.ext,其中 <sha256> 是文件内容的 SHA256 哈希的前八位或更多数字。哈希用于确保唯一名称并验证文件内容。默认值:False

  • file_name (str, 可选) – 下载文件的名称。如果未设置,将使用 url 中的文件名。

  • weights_only (bool, 可选) – 如果为 True,则仅加载权重,不加载复杂的 pickled 对象。建议用于不受信任的源。有关更多详细信息,请参见 load()

返回类型

dict[str, Any]

示例

>>> state_dict = torch.hub.load_state_dict_from_url(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
... )

运行已加载的模型:#

请注意,torch.hub.load() 中的 *args**kwargs 用于实例化模型。加载模型后,如何知道模型可以执行哪些操作?建议的工作流程是:

  • dir(model) 查看模型所有可用的方法。

  • help(model.foo) 检查 model.foo 运行时需要哪些参数。

为帮助用户无需反复查阅文档即可探索,我们强烈建议仓库所有者使函数帮助信息清晰简洁。包含一个最小的工作示例也很有帮助。

我的下载模型保存在哪里?#

位置按以下顺序使用:

  • 调用 hub.set_dir(<PATH_TO_HUB_DIR>)

  • $TORCH_HOME/hub,如果设置了环境变量 TORCH_HOME

  • $XDG_CACHE_HOME/torch/hub,如果设置了环境变量 XDG_CACHE_HOME

  • ~/.cache/torch/hub

torch.hub.get_dir()[source]#

获取用于存储下载模型和权重的 Torch Hub 缓存目录。

如果未调用 set_dir(),则默认路径为 $TORCH_HOME/hub,其中环境变量 $TORCH_HOME 默认为 $XDG_CACHE_HOME/torch$XDG_CACHE_HOME 遵循 Linux 文件系统布局的 X Design Group 规范,如果未设置环境变量,则默认值为 ~/.cache

返回类型

str

torch.hub.set_dir(d)[source]#

可选地设置用于保存下载模型和权重的 Torch Hub 目录。

参数

d (str) – 用于保存下载模型和权重的本地文件夹路径。

缓存逻辑#

默认情况下,加载文件后我们不会清理文件。如果 Hub 目录中已经存在,Hub 默认使用由 get_dir() 返回的缓存。

用户可以通过调用 hub.load(..., force_reload=True) 强制重新加载。这将删除现有的 GitHub 文件夹和已下载的权重,并重新初始化新的下载。当更新发布到同一分支时,这对于用户保持最新版本很有用。

已知限制:#

Torch hub 的工作原理是像安装包一样导入包。Python 中的导入会带来一些副作用。例如,您可能会在 Python 缓存 sys.modulessys.path_importer_cache 中看到新项,这是正常的 Python 行为。这也意味着,如果不同的仓库具有相同的子包名称(通常是 model 子包),则在从不同仓库导入不同模型时可能会遇到导入错误。解决此类导入错误的一种方法是从 sys.modules 字典中删除有问题的子包;更多详细信息可在此 GitHub 问题中找到。

一个值得一提的已知限制是:用户不能同一个 Python 进程中加载同一仓库的两个不同分支。这就像在 Python 中安装两个同名包一样,这是不好的。如果你真的尝试这样做,缓存可能会参与进来并给你带来惊喜。当然,在单独的进程中加载它们是完全没问题的。