torch.hub#
创建于: 2025年6月13日 | 最后更新于: 2025年6月13日
Pytorch Hub 是一个预训练模型仓库,旨在促进研究的可复现性。
发布模型#
Pytorch Hub 支持通过添加一个简单的 hubconf.py 文件,将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库;
hubconf.py 可以有多个入口点。每个入口点都被定义为一个 Python 函数(例如:您想要发布的预训练模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如何实现一个入口点?#
以下是一个代码片段,用于指定 resnet18 模型在 pytorch/vision/hubconf.py 中的入口点实现。在大多数情况下,导入正确的函数就足够了。这里我们仅使用展开的版本作为示例来展示其工作原理。您可以在 pytorch/vision repo 中找到完整的脚本。
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies变量是 **加载** 模型所需的包名称的 **列表**。请注意,这可能与训练模型所需的依赖项略有不同。args和kwargs会被传递给实际的可调用函数。函数的文档字符串用作帮助消息。它解释了模型的作用以及允许的位置/关键字参数。强烈建议在此处添加一些示例。
入口点函数可以返回一个模型(nn.module),或者辅助工具以使开发者的工作流程更顺畅,例如 tokenizers。
以下划线开头的可调用函数被视为辅助函数,它们不会出现在
torch.hub.list()中。预训练权重可以存储在 GitHub 仓库中,也可以通过
torch.hub.load_state_dict_from_url()加载。如果文件小于 2GB,建议将其附加到 项目发布 中,并使用发布中的 URL。在上面的示例中,torchvision.models.resnet.resnet18处理了pretrained参数,您也可以将以下逻辑放入入口点定义中。
if pretrained:
# For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要提示#
发布的模型应至少在分支/标签上。不能是任意提交。
从 Hub 加载模型#
Pytorch Hub 提供了便捷的 API,通过 torch.hub.list() 探索 hub 中所有可用的模型,通过 torch.hub.help() 显示文档字符串和示例,并通过 torch.hub.load() 加载预训练模型。
- torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[source]#
列出由
github指定的仓库中所有可用的入口点。- 参数:
github (str) – 格式为“repo_owner/repo_name[:ref]”的字符串,其中 :ref 是可选的(标签或分支)。如果未指定
ref,则默认分支假定为main(如果存在),否则为master。例如:‘pytorch/vision:0.10’force_reload (bool, optional) – 是否丢弃现有缓存并强制重新下载。默认为
False。skip_validation (bool, optional) – 如果为
False,torchhub 将检查github参数指定的ref是否正确属于仓库所有者。这将向 GitHub API 发送请求;您可以通过设置GITHUB_TOKEN环境变量来指定非默认的 GitHub token。默认为False。trust_repo (bool, str or None) –
"check",True,False或None。此参数在 v1.12 中引入,有助于确保用户仅运行受信任仓库的代码。如果为
False,将提示用户是否信任该仓库。如果为
True,该仓库将被添加到信任列表,并在无需显式确认的情况下加载。如果为
"check",该仓库将在缓存的信任仓库列表中进行检查。如果不在该列表中,行为将回退到trust_repo=False选项。如果为
None:这将发出警告,邀请用户将trust_repo设置为False,True或"check"。这仅为向后兼容性而存在,将在 v2.0 中移除。
默认为
None,在 v2.0 中将更改为"check"。verbose (bool, optional) – 如果为
False,则静默关于命中本地缓存的消息。请注意,关于首次下载的消息无法静默。默认为True。
- 返回:
可用的可调用入口点
- 返回类型:
示例
>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
- torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[source]#
显示入口点
model的文档字符串。- 参数:
github (str) – 格式为 <repo_owner/repo_name[:ref]> 的字符串,其中 :ref 是可选的(一个标签或一个分支)。如果未指定
ref,则默认分支假定为main(如果存在),否则为master。例如:‘pytorch/vision:0.10’model (str) – 仓库
hubconf.py中定义的入口点名称。force_reload (bool, optional) – 是否丢弃现有缓存并强制重新下载。默认为
False。skip_validation (bool, optional) – 如果为
False,torchhub 将检查github参数指定的ref是否正确属于仓库所有者。这将向 GitHub API 发送请求;您可以通过设置GITHUB_TOKEN环境变量来指定非默认的 GitHub token。默认为False。trust_repo (bool, str or None) –
"check",True,False或None。此参数在 v1.12 中引入,有助于确保用户仅运行受信任仓库的代码。如果为
False,将提示用户是否信任该仓库。如果为
True,该仓库将被添加到信任列表,并在无需显式确认的情况下加载。如果为
"check",该仓库将在缓存的信任仓库列表中进行检查。如果不在该列表中,行为将回退到trust_repo=False选项。如果为
None:这将发出警告,邀请用户将trust_repo设置为False,True或"check"。这仅为向后兼容性而存在,将在 v2.0 中移除。
默认为
None,在 v2.0 中将更改为"check"。
示例
>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
- torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[source]#
从 GitHub 仓库或本地目录加载模型。
注意:加载模型是典型的用例,但它也可以用于加载其他对象,如 tokenizers、损失函数等。
如果
source是 ‘github’,则repo_or_dir应该具有repo_owner/repo_name[:ref]的形式,其中 :ref 是可选的(一个标签或一个分支)。如果
source是 ‘local’,则repo_or_dir应该是一个本地目录的路径。- 参数:
repo_or_dir (str) – 如果
source是 ‘github’,这应该对应一个格式为repo_owner/repo_name[:ref]的 GitHub 仓库,其中 :ref 是可选的(标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定ref,则默认分支假定为main(如果存在),否则为master。如果source是 ‘local’,则它应该是本地目录的路径。model (str) – 仓库/目录的
hubconf.py中定义的某个可调用项(入口点)的名称。*args (optional) – 可调用项
model的相应参数。source (str, optional) – ‘github’ 或 ‘local’。指定如何解释
repo_or_dir。默认为 ‘github’。trust_repo (bool, str or None) –
"check",True,False或None。此参数在 v1.12 中引入,有助于确保用户仅运行受信任仓库的代码。如果为
False,将提示用户是否信任该仓库。如果为
True,该仓库将被添加到信任列表,并在无需显式确认的情况下加载。如果为
"check",该仓库将在缓存的信任仓库列表中进行检查。如果不在该列表中,行为将回退到trust_repo=False选项。如果为
None:这将发出警告,邀请用户将trust_repo设置为False,True或"check"。这仅为向后兼容性而存在,将在 v2.0 中移除。
默认为
None,在 v2.0 中将更改为"check"。force_reload (bool, optional) – 是否无条件地强制重新下载 GitHub 仓库。如果
source = 'local',则无效。默认为False。verbose (bool, optional) – 如果为
False,则静默关于命中本地缓存的消息。请注意,关于首次下载的消息无法静默。如果source = 'local',则无效。默认为True。skip_validation (bool, optional) – 如果为
False,torchhub 将检查github参数指定的ref是否正确属于仓库所有者。这将向 GitHub API 发送请求;您可以通过设置GITHUB_TOKEN环境变量来指定非默认的 GitHub token。默认为False。**kwargs (optional) – 可调用项
model的相应关键字参数。
- 返回:
使用给定的
*args和**kwargs调用model可调用项的输出。
示例
>>> # from a github repo >>> repo = "pytorch/vision" >>> model = torch.hub.load( ... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1" ... ) >>> # from a local directory >>> path = "/some/local/path/pytorch/vision" >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
- torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[source]#
将给定 URL 的对象下载到本地路径。
- 参数:
示例
>>> torch.hub.download_url_to_file( ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", ... "/tmp/temporary_file", ... )
- torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source]#
加载给定 URL 的 Torch 序列化对象。
如果下载的文件是 zip 文件,它将被自动解压。
如果对象已存在于 model_dir 中,它将被反序列化并返回。
model_dir的默认值为<hub_dir>/checkpoints,其中hub_dir是get_dir()返回的目录。- 参数:
url (str) – 要下载的对象的 URL
model_dir (str, optional) – 用于保存对象的目录
map_location (optional) – 一个函数或字典,指定如何重新映射存储位置(参见 torch.load)
progress (bool, optional) – 是否向 stderr 显示进度条。默认值:True
check_hash (bool, optional) – 如果为 True,URL 的文件名部分应遵循命名约定
filename-<sha256>.ext,其中<sha256>是文件内容 SHA256 哈希的前八位或更多数字。哈希用于确保名称的唯一性并验证文件内容。默认值:Falsefile_name (str, optional) – 下载文件的名称。如果未设置,将使用 URL 中的文件名。
weights_only (bool, optional) – 如果为 True,则只加载权重,不加载复杂的 pickled 对象。推荐用于不可信的来源。有关更多详细信息,请参阅
load()。
- 返回类型:
示例
>>> state_dict = torch.hub.load_state_dict_from_url( ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth" ... )
运行加载的模型:#
请注意,torch.hub.load() 中的 *args 和 **kwargs 用于 **实例化** 模型。加载模型后,如何了解可以对模型做什么?一个建议的工作流程是:
dir(model)查看模型的所有可用方法。help(model.foo)检查model.foo需要哪些参数才能运行。
为了帮助用户在不来回查阅文档的情况下进行探索,我们强烈建议仓库所有者使函数帮助消息清晰简洁。包含一个最小工作示例也很有帮助。
我的下载模型保存在哪里?#
位置按以下顺序使用:
调用
hub.set_dir(<PATH_TO_HUB_DIR>)$TORCH_HOME/hub,如果设置了环境变量TORCH_HOME。$XDG_CACHE_HOME/torch/hub,如果设置了环境变量XDG_CACHE_HOME。~/.cache/torch/hub
缓存逻辑#
默认情况下,我们不会在加载文件后清理它们。Hub 默认使用缓存,如果它已存在于 get_dir() 返回的目录中。
用户可以通过调用 hub.load(..., force_reload=True) 来强制重新加载。这将删除现有的 GitHub 文件夹和下载的权重,并重新初始化一次新的下载。当更新发布到同一分支时,这很有用,用户可以跟上最新版本。
已知限制:#
Torch hub 的工作方式是像导入已安装的包一样导入。在 Python 中导入会引入一些副作用。例如,您会在 Python 缓存 sys.modules 和 sys.path_importer_cache 中看到新项目,这是正常的 Python 行为。这也意味着,如果不同的仓库具有相同的子包名称(通常是 model 子包),则在从不同仓库导入不同模型时可能会遇到导入错误。一种解决方法是删除 sys.modules 字典中的有问题的子包;有关更多详细信息,请参阅 此 GitHub issue。
一个值得在此提及的已知限制是:用户 **不能** 在 **同一个 Python 进程** 中加载同一仓库的两个不同分支。这就像在 Python 中安装两个同名包一样,这是不好的。如果您尝试这样做,缓存可能会出现问题并给您带来惊喜。当然,在不同的进程中加载它们是完全没问题的。