AveragedModel#
- class torch.optim.swa_utils.AveragedModel(model, device=None, avg_fn=None, multi_avg_fn=None, use_buffers=False)[source]#
实现了用于随机权重平均(SWA)和指数移动平均(EMA)的平均模型。
随机权重平均由 Pavel Izmailov、Dmitrii Podoprikhin、Timur Garipov、Dmitry Vetrov 和 Andrew Gordon Wilson 在 Averaging Weights Leads to Wider Optima and Better Generalization (UAI 2018) 中提出。
指数移动平均是 Polyak averaging 的一种变体,但它使用指数权重而不是迭代中的相等权重。
AveragedModel 类在
device
设备上创建提供的模块model
的副本,并允许计算model
参数的运行平均值。- 参数
model (torch.nn.Module) – 用于 SWA/EMA 的模型
device (torch.device, optional) – 如果提供,则平均模型将存储在
device
上avg_fn (function, optional) – 用于更新参数的平均函数;该函数必须接受
AveragedModel
参数的当前值、model
参数的当前值以及已平均的模型数量;如果为 None,则使用相等权重的平均值(默认:None)multi_avg_fn (function, optional) – 用于原地更新参数的平均函数;该函数必须接受列表形式的
AveragedModel
参数的当前值、列表形式的model
参数的当前值以及已平均的模型数量;如果为 None,则使用相等权重的平均值(默认:None)use_buffers (bool) – 如果为
True
,则会同时计算模型参数和缓冲区的运行平均值。(默认:False
)
示例
>>> loader, optimizer, model, loss_fn = ... >>> swa_model = torch.optim.swa_utils.AveragedModel(model) >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, >>> T_max=300) >>> swa_start = 160 >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05) >>> for i in range(300): >>> for input, target in loader: >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() >>> optimizer.step() >>> if i > swa_start: >>> swa_model.update_parameters(model) >>> swa_scheduler.step() >>> else: >>> scheduler.step() >>> >>> # Update bn statistics for the swa_model at the end >>> torch.optim.swa_utils.update_bn(loader, swa_model)
您也可以使用 avg_fn 或 multi_avg_fn 参数自定义平均函数。如果未提供平均函数,则默认值为计算权重的平均值(SWA)。
示例
>>> # Compute exponential moving averages of the weights and buffers >>> ema_model = torch.optim.swa_utils.AveragedModel(model, >>> torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)
注意
在使用包含 Batch Normalization 的模型进行 SWA/EMA 时,您可能需要更新 Batch Normalization 的激活统计信息。这可以通过使用
torch.optim.swa_utils.update_bn()
或将use_buffers
设置为 True 来完成。第一种方法通过让数据通过模型来更新训练后的统计信息。第二种方法通过平均所有缓冲区在参数更新阶段进行更新。经验证据表明,更新归一化层中的统计信息可以提高准确性,但您可能需要通过实验来确定哪种方法在您的项目中能产生最佳结果。注意
avg_fn
和 multi_avg_fn 不会保存在模型的state_dict()
中。注意
当首次调用
update_parameters()
时(即n_averaged
为 0),model 的参数会被复制到AveragedModel
的参数中。之后每次调用update_parameters()
时,都会使用 avg_fn 函数来更新参数。- apply(fn)[source]#
将
fn
递归应用于每个子模块(由.children()
返回)以及自身。典型用途包括初始化模型的参数(另请参阅 torch.nn.init)。
- 参数
fn (
Module
-> None) – 要应用于每个子模块的函数- 返回
self
- 返回类型
示例
>>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )
- buffers(recurse=True)[source]#
返回模块缓冲区的迭代器。
- 参数
recurse (bool) – 如果为 True,则产生此模块及其所有子模块的缓冲区。否则,只产生作为此模块直接成员的缓冲区。
- 生成
torch.Tensor – 模块缓冲区
- 返回类型
示例
>>> for buf in model.buffers(): >>> print(type(buf), buf.size()) <class 'torch.Tensor'> (20L,) <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- compile(*args, **kwargs)[source]#
使用
torch.compile()
编译此模块的 forward。此模块的 __call__ 方法已编译,所有参数按原样传递给
torch.compile()
。有关此函数参数的详细信息,请参阅
torch.compile()
。
- cuda(device=None)[源]#
将所有模型参数和缓冲区移动到 GPU。
这也会使相关的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 GPU 上,则应在构建优化器之前调用此函数。
注意
此方法就地修改模块。
- eval()[源]#
将模块设置为评估模式。
此设置仅对某些模块有影响。请参阅特定模块的文档,了解它们在训练/评估模式下的行为,例如它们是否受影响,例如
Dropout
、BatchNorm
等。此操作等同于
self.train(False)
。请参阅 局部禁用梯度计算,了解 .eval() 与几种可能与之混淆的类似机制的比较。
- 返回
self
- 返回类型
- get_buffer(target)[源]#
返回由
target
给定的缓冲区(如果存在),否则抛出错误。有关此方法功能的更详细解释以及如何正确指定
target
,请参阅get_submodule
的文档字符串。- 参数
target (str) – 要查找的缓冲区的完全限定字符串名称。(请参阅
get_submodule
以了解如何指定完全限定字符串。)- 返回
由
target
引用的缓冲区- 返回类型
- 引发
AttributeError – 如果目标字符串引用了无效路径或解析为非缓冲区对象。
- get_extra_state()[源]#
返回要包含在模块 state_dict 中的任何额外状态。
如果需要存储额外状态,请实现此方法以及相应的
set_extra_state()
。调用此函数时会构建模块的 state_dict()。请注意,额外状态应该是可序列化的,以确保 state_dict 的序列化有效。我们仅为序列化 Tensor 提供向后兼容性保证;其他对象的序列化其固定形式如果改变,可能会破坏向后兼容性。
- 返回
要存储在模块 state_dict 中的任何额外状态
- 返回类型
- get_parameter(target)[源]#
如果存在,返回由
target
给定的参数,否则抛出错误。有关此方法功能的更详细解释以及如何正确指定
target
,请参阅get_submodule
的文档字符串。- 参数
target (str) – 要查找的参数的完全限定字符串名称。(请参阅
get_submodule
以了解如何指定完全限定字符串。)- 返回
由
target
引用的参数- 返回类型
torch.nn.Parameter
- 引发
AttributeError – 如果目标字符串引用了无效路径或解析为非
nn.Parameter
对象。
- get_submodule(target)[源]#
如果存在,返回由
target
给定的子模块,否则抛出错误。例如,假设您有一个
nn.Module
A
,它看起来像这样A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) )
(图示显示了一个
nn.Module
A
。A
包含一个嵌套的子模块net_b
,它本身有两个子模块net_c
和linear
。net_c
然后有一个子模块conv
。)为了检查我们是否拥有
linear
子模块,我们将调用get_submodule("net_b.linear")
。为了检查我们是否拥有conv
子模块,我们将调用get_submodule("net_b.net_c.conv")
。get_submodule
的运行时受到target
中模块嵌套深度的限制。对named_modules
的查询可以达到相同的结果,但它是相对于传递模块数量的 O(N) 操作。因此,对于检查某个子模块是否存在这样的简单检查,应始终使用get_submodule
。- 参数
target (str) – 要查找的子模块的完全限定字符串名称。(请参阅上面的示例以了解如何指定完全限定字符串。)
- 返回
由
target
引用的子模块- 返回类型
- 引发
AttributeError – 如果在目标字符串路径中的任何点,(子)路径解析为不存在的属性名称或不是
nn.Module
实例的对象。
- ipu(device=None)[源]#
将所有模型参数和缓冲区移动到 IPU。
这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 IPU 上,则应在构建优化器之前调用它。
注意
此方法就地修改模块。
- load_state_dict(state_dict, strict=True, assign=False)[源]#
将
state_dict
中的参数和缓冲区复制到此模块及其后代中。如果
strict
为True
,则state_dict
的键必须与此模块的state_dict()
函数返回的键完全匹配。警告
如果
assign
为True
,则必须在调用load_state_dict
之后创建优化器,除非get_swap_module_params_on_conversion()
为True
。- 参数
state_dict (dict) – 包含参数和持久缓冲区的字典。
strict (bool, optional) – 是否严格强制
state_dict
中的键与此模块的state_dict()
函数返回的键匹配。默认为True
。assign (bool, optional) – 当设置为
False
时,将保留当前模块的张量属性;而设置为True
时,将保留 state dict 中张量的属性。唯一例外是Parameter
的requires_grad
字段,此时会保留来自模块的值。默认为False
。
- 返回
missing_keys
是一个包含此模块期望但在提供的
state_dict
中缺失的任何键的字符串列表。
unexpected_keys
是一个字符串列表,包含此模块不期望但在提供的
state_dict
中存在的键。
- 返回类型
具有
missing_keys
和unexpected_keys
字段的NamedTuple
。
注意
如果参数或缓冲区被注册为
None
且其对应的键存在于state_dict
中,load_state_dict()
将会引发RuntimeError
。
- modules()[源]#
返回网络中所有模块的迭代器。
注意
重复的模块只返回一次。在以下示例中,
l
只返回一次。示例
>>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True)
- mtia(device=None)[源]#
将所有模型参数和缓冲区移动到 MTIA。
这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 MTIA 上,则应在构建优化器之前调用它。
注意
此方法就地修改模块。
- named_buffers(prefix='', recurse=True, remove_duplicate=True)[源]#
返回模块缓冲区上的迭代器,同时生成缓冲区的名称和缓冲区本身。
- 参数
- 生成
(str, torch.Tensor) – 包含名称和缓冲区的元组
- 返回类型
示例
>>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size())
- named_children()[源]#
返回对直接子模块的迭代器,生成模块的名称和模块本身。
示例
>>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module)
- named_modules(memo=None, prefix='', remove_duplicate=True)[源]#
返回网络中所有模块的迭代器,同时生成模块的名称和模块本身。
- 参数
- 生成
(str, Module) – 名称和模块的元组
注意
重复的模块只返回一次。在以下示例中,
l
只返回一次。示例
>>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
- named_parameters(prefix='', recurse=True, remove_duplicate=True)[源]#
返回模块参数的迭代器,同时生成参数的名称和参数本身。
- 参数
- 生成
(str, Parameter) – 包含名称和参数的元组
- 返回类型
示例
>>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size())
- parameters(recurse=True)[源]#
返回模块参数的迭代器。
这通常传递给优化器。
- 参数
recurse (bool) – 如果为 True,则会生成此模块和所有子模块的参数。否则,只生成此模块直接包含的参数。
- 生成
Parameter – 模块参数
- 返回类型
示例
>>> for param in model.parameters(): >>> print(type(param), param.size()) <class 'torch.Tensor'> (20L,) <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- register_backward_hook(hook)[源]#
在模块上注册一个反向传播钩子。
此函数已弃用,推荐使用
register_full_backward_hook()
,并且此函数在未来版本中的行为将发生更改。- 返回
一个句柄,可用于通过调用
handle.remove()
来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_buffer(name, tensor, persistent=True)[源]#
向模块添加一个缓冲区。
这通常用于注册一个不应被视为模型参数的缓冲区。例如,BatchNorm 的
running_mean
不是参数,但它是模块状态的一部分。缓冲区默认是持久化的,会与参数一起保存。此行为可以通过将persistent
设置为False
来更改。持久化缓冲区和非持久化缓冲区之间的唯一区别在于后者不会成为此模块state_dict
的一部分。可以使用给定名称作为属性访问缓冲区。
- 参数
name (str) – 缓冲区的名称。可以使用给定名称从此模块访问缓冲区
tensor (Tensor 或 None) – 要注册的缓冲区。如果为
None
,则跳过在缓冲区上运行的操作,如cuda
。如果为None
,则缓冲区 **不** 包含在模块的state_dict
中。persistent (bool) – 缓冲区是否是此模块
state_dict
的一部分。
示例
>>> self.register_buffer('running_mean', torch.zeros(num_features))
- register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)[源]#
在模块上注册一个前向钩子。
在
forward()
计算完输出后,将为每个输出调用此 hook。如果
with_kwargs
为False
或未指定,则输入仅包含传递给模块的位置参数。关键字参数将不会传递给 hook,只会传递给forward
。hook 可以修改输出。它可以在原地修改输入,但不会影响前向传播,因为它是在forward()
调用之后调用的。hook 应该具有以下签名:hook(module, args, output) -> None or modified output
如果
with_kwargs
为True
,则 forward hook 将收到传递给 forward 函数的kwargs
,并需要返回可能已修改的输出。hook 应该具有以下签名:hook(module, args, kwargs, output) -> None or modified output
- 参数
hook (Callable) – 用户定义的待注册钩子。
prepend (bool) – 如果为
True
,则提供的hook
将在对此torch.nn.Module
的所有现有forward
hook 之前触发。否则,提供的hook
将在此torch.nn.Module
的所有现有forward
hook 之后触发。请注意,通过register_module_forward_hook()
注册的全局forward
hook 将在此方法注册的所有 hook 之前触发。默认为False
。with_kwargs (bool) – 如果为
True
,则hook
将收到传递给 forward 函数的 kwargs。默认为False
。always_call (bool) – 如果为
True
,则无论在调用 Module 时是否引发异常,都会运行hook
。默认为False
。
- 返回
一个句柄,可用于通过调用
handle.remove()
来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)[源]#
在模块上注册一个前向预钩子。
在调用
forward()
之前,将为每个调用调用此 hook。如果
with_kwargs
为 false 或未指定,则输入仅包含传递给模块的位置参数。关键字参数将不会传递给 hook,只会传递给forward
。hook 可以修改输入。用户可以返回一个元组或 hook 中的单个修改值。如果返回单个值(除非该值已经是元组),我们将将其包装成元组。hook 应该具有以下签名:hook(module, args) -> None or modified input
如果
with_kwargs
为 true,则 forward pre-hook 将收到传递给 forward 函数的 kwargs。如果 hook 修改了输入,则应同时返回 args 和 kwargs。hook 应该具有以下签名:hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
- 参数
hook (Callable) – 用户定义的待注册钩子。
prepend (bool) – 如果为 true,则提供的
hook
将在此torch.nn.Module
的所有现有forward_pre
hook 之前触发。否则,提供的hook
将在此torch.nn.Module
的所有现有forward_pre
hook 之后触发。请注意,通过register_module_forward_pre_hook()
注册的全局forward_pre
hook 将在此方法注册的所有 hook 之前触发。默认为False
。with_kwargs (bool) – 如果为 true,则
hook
将收到传递给 forward 函数的 kwargs。默认为False
。
- 返回
一个句柄,可用于通过调用
handle.remove()
来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_full_backward_hook(hook, prepend=False)[源]#
在模块上注册一个反向传播钩子。
每次计算相对于模块的梯度时,将调用此钩子,其触发规则如下:
通常,钩子在计算相对于模块输入的梯度时触发。
如果模块输入都不需要梯度,则在计算相对于模块输出的梯度时触发钩子。
如果模块输出都不需要梯度,则钩子将不触发。
钩子应具有以下签名
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
grad_input
和grad_output
是包含相对于输入和输出的梯度的元组。hook 不应修改其参数,但可以选择性地返回一个相对于输入的、将用于替换grad_input
的新梯度。grad_input
将仅对应于作为位置参数给出的输入,并且所有关键字参数都将被忽略。grad_input
和grad_output
中的条目对于所有非 Tensor 参数将为None
。由于技术原因,当此钩子应用于模块时,其前向函数将接收传递给模块的每个张量的视图。类似地,调用者将接收模块前向函数返回的每个张量的视图。
警告
使用反向传播钩子时不允许就地修改输入或输出,否则将引发错误。
- 参数
hook (Callable) – 要注册的用户定义钩子。
prepend (bool) – 如果为 true,则提供的
hook
将在此torch.nn.Module
的所有现有backward
hook 之前触发。否则,提供的hook
将在此torch.nn.Module
的所有现有backward
hook 之后触发。请注意,通过register_module_full_backward_hook()
注册的全局backward
hook 将在此方法注册的所有 hook 之前触发。
- 返回
一个句柄,可用于通过调用
handle.remove()
来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_full_backward_pre_hook(hook, prepend=False)[源]#
在模块上注册一个反向预钩子。
每次计算模块的梯度时,将调用此钩子。钩子应具有以下签名
hook(module, grad_output) -> tuple[Tensor] or None
The
grad_output
is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place ofgrad_output
in subsequent computations. Entries ingrad_output
will beNone
for all non-Tensor arguments.由于技术原因,当此钩子应用于模块时,其前向函数将接收传递给模块的每个张量的视图。类似地,调用者将接收模块前向函数返回的每个张量的视图。
警告
使用反向传播钩子时不允许就地修改输入,否则将引发错误。
- 参数
hook (Callable) – 要注册的用户定义钩子。
prepend (bool) – If true, the provided
hook
will be fired before all existingbackward_pre
hooks on thistorch.nn.Module
. Otherwise, the providedhook
will be fired after all existingbackward_pre
hooks on thistorch.nn.Module
. Note that globalbackward_pre
hooks registered withregister_module_full_backward_pre_hook()
will fire before all hooks registered by this method.
- 返回
一个句柄,可用于通过调用
handle.remove()
来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_load_state_dict_post_hook(hook)[源]#
注册一个后钩子,用于在模块的
load_state_dict()
被调用后运行。- 它应该具有以下签名:
hook(module, incompatible_keys) -> None
The
module
argument is the current module that this hook is registered on, and theincompatible_keys
argument is aNamedTuple
consisting of attributesmissing_keys
andunexpected_keys
.missing_keys
is alist
ofstr
containing the missing keys andunexpected_keys
is alist
ofstr
containing the unexpected keys.如果需要,可以就地修改给定的 incompatible_keys。
Note that the checks performed when calling
load_state_dict()
withstrict=True
are affected by modifications the hook makes tomissing_keys
orunexpected_keys
, as expected. Additions to either set of keys will result in an error being thrown whenstrict=True
, and clearing out both missing and unexpected keys will avoid an error.- 返回
一个句柄,可用于通过调用
handle.remove()
来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_load_state_dict_pre_hook(hook)[源]#
注册一个预钩子,用于在模块的
load_state_dict()
被调用之前运行。- 它应该具有以下签名:
hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
- 参数
hook (Callable) – 在加载状态字典之前将调用的可调用钩子。
- register_module(name, module)[源]#
Alias for
add_module()
.
- register_parameter(name, param)[源]#
向模块添加一个参数。
可以使用给定名称作为属性访问该参数。
- 参数
name (str) – 参数的名称。可以使用给定名称从此模块访问该参数
param (Parameter or None) – parameter to be added to the module. If
None
, then operations that run on parameters, such ascuda
, are ignored. IfNone
, the parameter is not included in the module’sstate_dict
.
- register_state_dict_post_hook(hook)[源]#
Register a post-hook for the
state_dict()
method.- 它应该具有以下签名:
hook(module, state_dict, prefix, local_metadata) -> None
注册的钩子可以就地修改
state_dict
。
- register_state_dict_pre_hook(hook)[源]#
Register a pre-hook for the
state_dict()
method.- 它应该具有以下签名:
hook(module, prefix, keep_vars) -> None
注册的钩子可用于在进行
state_dict
调用之前执行预处理。
- requires_grad_(requires_grad=True)[源]#
更改自动梯度是否应记录此模块中参数的操作。
此方法就地设置参数的
requires_grad
属性。此方法有助于冻结模块的一部分以进行微调或单独训练模型的一部分(例如,GAN 训练)。
See Locally disabling gradient computation for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.
- set_extra_state(state)[源]#
设置加载的 state_dict 中包含的额外状态。
This function is called from
load_state_dict()
to handle any extra state found within the state_dict. Implement this function and a correspondingget_extra_state()
for your module if you need to store extra state within its state_dict.- 参数
state (dict) – state_dict 中的额外状态
- set_submodule(target, module, strict=False)[源]#
如果存在,设置由
target
给定的子模块,否则抛出错误。注意
If
strict
is set toFalse
(default), the method will replace an existing submodule or create a new submodule if the parent module exists. Ifstrict
is set toTrue
, the method will only attempt to replace an existing submodule and throw an error if the submodule does not exist.例如,假设您有一个
nn.Module
A
,它看起来像这样A( (net_b): Module( (net_c): Module( (conv): Conv2d(3, 3, 3) ) (linear): Linear(3, 3) ) )
(The diagram shows an
nn.Module
A
.A
has a nested submodulenet_b
, which itself has two submodulesnet_c
andlinear
.net_c
then has a submoduleconv
.)To override the
Conv2d
with a new submoduleLinear
, you could callset_submodule("net_b.net_c.conv", nn.Linear(1, 1))
wherestrict
could beTrue
orFalse
To add a new submodule
Conv2d
to the existingnet_b
module, you would callset_submodule("net_b.conv", nn.Conv2d(1, 1, 1))
.In the above if you set
strict=True
and callset_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True)
, an AttributeError will be raised becausenet_b
does not have a submodule namedconv
.- 参数
target (str) – 要查找的子模块的完全限定字符串名称。(请参阅上面的示例以了解如何指定完全限定字符串。)
module (Module) – The module to set the submodule to.
strict (bool) – If
False
, the method will replace an existing submodule or create a new submodule if the parent module exists. IfTrue
, the method will only attempt to replace an existing submodule and throw an error if the submodule doesn’t already exist.
- 引发
ValueError – If the
target
string is empty or ifmodule
is not an instance ofnn.Module
.AttributeError – If at any point along the path resulting from the
target
string the (sub)path resolves to a non-existent attribute name or an object that is not an instance ofnn.Module
.
See
torch.Tensor.share_memory_()
.- 返回类型
自我
- state_dict(*args, destination=None, prefix='', keep_vars=False)[源]#
返回一个字典,其中包含对模块整个状态的引用。
参数和持久缓冲区(例如,运行平均值)都包含在内。键是相应的参数和缓冲区名称。设置为
None
的参数和缓冲区不包含在内。注意
返回的对象是浅拷贝。它包含对模块参数和缓冲区的引用。
警告
Currently
state_dict()
also accepts positional arguments fordestination
,prefix
andkeep_vars
in order. However, this is being deprecated and keyword arguments will be enforced in future releases.警告
请避免使用参数
destination
,因为它不是为最终用户设计的。- 参数
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an
OrderedDict
will be created and returned. Default:None
.prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default:
''
.keep_vars (bool, optional) – by default the
Tensor
s returned in the state dict are detached from autograd. If it’s set toTrue
, detaching will not be performed. Default:False
.
- 返回
包含模块整体状态的字典
- 返回类型
示例
>>> module.state_dict().keys() ['bias', 'weight']
- to(*args, **kwargs)[源]#
移动和/或转换参数和缓冲区。
这可以这样调用
- to(device=None, dtype=None, non_blocking=False)[源]
- to(dtype, non_blocking=False)[源]
- to(tensor, non_blocking=False)[源]
- to(memory_format=torch.channels_last)[源]
Its signature is similar to
torch.Tensor.to()
, but only accepts floating point or complexdtype
s. In addition, this method will only cast the floating point or complex parameters and buffers todtype
(if given). The integral parameters and buffers will be moveddevice
, if that is given, but with dtypes unchanged. Whennon_blocking
is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.有关示例,请参阅下文。
注意
此方法就地修改模块。
- 参数
device (
torch.device
) – the desired device of the parameters and buffers in this moduledtype (
torch.dtype
) – the desired floating point or complex dtype of the parameters and buffers in this moduletensor (torch.Tensor) – 张量,其 dtype 和设备是此模块中所有参数和缓冲区的所需 dtype 和设备
memory_format (
torch.memory_format
) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)
- 返回
self
- 返回类型
示例
>>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> gpu1 = torch.device("cuda:1") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
- to_empty(*, device, recurse=True)[源]#
将参数和缓冲区移动到指定设备,而不复制存储。
- 参数
device (
torch.device
) – The desired device of the parameters and buffers in this module.recurse (bool) – Whether parameters and buffers of submodules should be recursively moved to the specified device.
- 返回
self
- 返回类型
- train(mode=True)[源]#
将模块设置为训练模式。
This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g.
Dropout
,BatchNorm
, etc.
- xpu(device=None)[源]#
将所有模型参数和缓冲区移动到 XPU。
这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 XPU 上,则应在构建优化器之前调用它。
注意
此方法就地修改模块。
- zero_grad(set_to_none=True)[源]#
重置所有模型参数的梯度。
See similar function under
torch.optim.Optimizer
for more context.- 参数
set_to_none (bool) – instead of setting to zero, set the grads to None. See
torch.optim.Optimizer.zero_grad()
for details.