ScriptModule#
- class torch.jit.ScriptModule[source]#
Wrapper for C++ torch::jit::Module with methods, attributes, and parameters.
C++
torch::jit::Module的封装。ScriptModule包含方法、属性、参数和常量。这些可以与普通nn.Module相同的方式访问。- apply(fn)[source]#
将
fn递归应用于每个子模块(由.children()返回)以及自身。典型用途包括初始化模型的参数(另请参阅 torch.nn.init)。
- 参数
fn (
Module-> None) – 要应用于每个子模块的函数- 返回
self
- 返回类型
示例
>>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )
- buffers(recurse=True)[source]#
返回模块缓冲区的迭代器。
- 参数
recurse (bool) – 如果为 True,则会生成此模块及所有子模块的缓冲区。否则,只生成此模块的直接成员缓冲区。
- 生成
torch.Tensor – 模块缓冲区
- 返回类型
示例
>>> for buf in model.buffers(): >>> print(type(buf), buf.size()) <class 'torch.Tensor'> (20L,) <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- property code#
Return a pretty-printed representation (as valid Python syntax) of the internal graph for the
forwardmethod.
- property code_with_constants#
Return a tuple.
Returns a tuple of
[0] a pretty-printed representation (as valid Python syntax) of the internal graph for the
forwardmethod. See code. [1] a ConstMap following the CONSTANT.cN format of the output in [0]. The indices in the [0] output are keys to the underlying constant’s values.
- compile(*args, **kwargs)[source]#
使用
torch.compile()编译此模块的 forward。此模块的 __call__ 方法已编译,所有参数将按原样传递给
torch.compile()。有关此函数参数的详细信息,请参阅
torch.compile()。
- cuda(device=None)[source]#
将所有模型参数和缓冲区移动到 GPU。
这也会使相关的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 GPU 上,则应在构建优化器之前调用此函数。
注意
此方法就地修改模块。
- eval()[source]#
将模块设置为评估模式。
这仅对某些模块有影响。有关模块在训练/评估模式下的行为,例如它们是否受影响(如
Dropout、BatchNorm等),请参阅具体模块的文档。This is equivalent with
self.train(False).请参阅 局部禁用梯度计算,了解 .eval() 与一些可能与之混淆的类似机制之间的比较。
- 返回
self
- 返回类型
- get_buffer(target)[source]#
返回由
target给定的缓冲区(如果存在),否则抛出错误。有关此方法功能的更详细解释以及如何正确指定
target,请参阅get_submodule的文档字符串。- 参数
target (str) – 要查找的缓冲区的完整限定字符串名称。(有关如何指定完整限定字符串,请参阅
get_submodule。)- 返回
由
target引用的缓冲区- 返回类型
- 引发
AttributeError – 如果目标字符串引用了无效路径或解析为非缓冲区项。
- get_extra_state()[source]#
返回要包含在模块 state_dict 中的任何额外状态。
Implement this and a corresponding
set_extra_state()for your module if you need to store extra state. This function is called when building the module’s state_dict().注意,为了保证 state_dict 的序列化工作正常,额外状态应该是可被 pickle 的。我们仅为 Tensors 的序列化提供向后兼容性保证;其他对象的序列化形式若发生变化,可能导致向后兼容性中断。
- 返回
要存储在模块 state_dict 中的任何额外状态
- 返回类型
- get_parameter(target)[source]#
如果存在,返回由
target给定的参数,否则抛出错误。有关此方法功能的更详细解释以及如何正确指定
target,请参阅get_submodule的文档字符串。- 参数
target (str) – 要查找的参数的完整限定字符串名称。(有关如何指定完整限定字符串,请参阅
get_submodule。)- 返回
由
target引用的参数- 返回类型
torch.nn.Parameter
- 引发
AttributeError – 如果目标字符串引用了无效路径或解析为非
nn.Parameter项。
- get_submodule(target)[source]#
如果存在,返回由
target给定的子模块,否则抛出错误。例如,假设您有一个
nn.ModuleA,它看起来像这样A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) )(图示了一个
nn.ModuleA。A包含一个嵌套子模块net_b,该子模块本身有两个子模块net_c和linear。net_c随后又有一个子模块conv。)要检查是否存在
linear子模块,可以调用get_submodule("net_b.linear")。要检查是否存在conv子模块,可以调用get_submodule("net_b.net_c.conv")。get_submodule的运行时复杂度受target中模块嵌套深度的限制。与named_modules的查询相比,后者的复杂度是按传递模块数量计算的 O(N)。因此,对于简单地检查某个子模块是否存在,应始终使用get_submodule。- 参数
target (str) – 要查找的子模块的完整限定字符串名称。(如上例所示,如何指定完整限定字符串。)
- 返回
由
target引用的子模块- 返回类型
- 引发
AttributeError – 如果在
target字符串解析出的路径中的任何一点,(子)路径解析为一个不存在的属性名或一个非nn.Module实例的对象。
- property graph#
Return a string representation of the internal graph for the
forwardmethod.
- property inlined_graph#
Return a string representation of the internal graph for the
forwardmethod.This graph will be preprocessed to inline all function and method calls.
- ipu(device=None)[source]#
将所有模型参数和缓冲区移动到 IPU。
这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 IPU 上,则应在构建优化器之前调用它。
注意
此方法就地修改模块。
- load_state_dict(state_dict, strict=True, assign=False)[source]#
Copy parameters and buffers from
state_dictinto this module and its descendants.If
strictisTrue, then the keys ofstate_dictmust exactly match the keys returned by this module’sstate_dict()function.警告
If
assignisTruethe optimizer must be created after the call toload_state_dictunlessget_swap_module_params_on_conversion()isTrue.- 参数
state_dict (dict) – 包含参数和持久缓冲区的字典。
strict (bool, optional) – whether to strictly enforce that the keys in
state_dictmatch the keys returned by this module’sstate_dict()function. Default:Trueassign (bool, optional) – 当设置为
False时,将保留当前模块中张量的属性;设置为True时,将保留 state dict 中张量的属性。唯一的例外是Parameter的requires_grad字段,此时将保留模块中的值。默认为False。
- 返回
missing_keys是一个包含此模块期望但在提供的
state_dict中缺失的任何键的字符串列表。
unexpected_keys是一个字符串列表,包含此模块不期望但在提供的
state_dict中存在的键。
- 返回类型
NamedTuple,包含missing_keys和unexpected_keys字段。
注意
If a parameter or buffer is registered as
Noneand its corresponding key exists instate_dict,load_state_dict()will raise aRuntimeError.
- modules()[source]#
返回网络中所有模块的迭代器。
注意
重复的模块只返回一次。在以下示例中,
l只返回一次。示例
>>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True)
- mtia(device=None)[source]#
将所有模型参数和缓冲区移动到 MTIA。
这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 MTIA 上,则应在构建优化器之前调用它。
注意
此方法就地修改模块。
- named_buffers(prefix='', recurse=True, remove_duplicate=True)[source]#
返回模块缓冲区上的迭代器,同时生成缓冲区的名称和缓冲区本身。
- 参数
- 生成
(str, torch.Tensor) – 包含名称和缓冲区的元组
- 返回类型
示例
>>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size())
- named_children()[source]#
返回对直接子模块的迭代器,生成模块的名称和模块本身。
示例
>>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module)
- named_modules(memo=None, prefix='', remove_duplicate=True)[source]#
返回网络中所有模块的迭代器,同时生成模块的名称和模块本身。
- 参数
- 生成
(str, Module) – 名称和模块的元组
注意
重复的模块只返回一次。在以下示例中,
l只返回一次。示例
>>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
- named_parameters(prefix='', recurse=True, remove_duplicate=True)[source]#
返回模块参数的迭代器,同时生成参数的名称和参数本身。
- 参数
- 生成
(str, Parameter) – 包含名称和参数的元组
- 返回类型
示例
>>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size())
- parameters(recurse=True)[source]#
返回模块参数的迭代器。
这通常传递给优化器。
- 参数
recurse (bool) – 如果为 True,则会生成此模块及所有子模块的参数。否则,只生成此模块的直接成员参数。
- 生成
Parameter – 模块参数
- 返回类型
示例
>>> for param in model.parameters(): >>> print(type(param), param.size()) <class 'torch.Tensor'> (20L,) <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- register_backward_hook(hook)[source]#
在模块上注册一个反向传播钩子。
This function is deprecated in favor of
register_full_backward_hook()and the behavior of this function will change in future versions.- 返回
一个句柄,可用于通过调用
handle.remove()来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_buffer(name, tensor, persistent=True)[source]#
向模块添加一个缓冲区。
This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s
running_meanis not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by settingpersistenttoFalse. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’sstate_dict.可以使用给定名称作为属性访问缓冲区。
- 参数
name (str) – 缓冲区的名称。缓冲区可以通过给定名称从该模块访问。
tensor (Tensor or None) – buffer to be registered. If
None, then operations that run on buffers, such ascuda, are ignored. IfNone, the buffer is not included in the module’sstate_dict.persistent (bool) – whether the buffer is part of this module’s
state_dict.
示例
>>> self.register_buffer('running_mean', torch.zeros(num_features))
- register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)[source]#
在模块上注册一个前向钩子。
The hook will be called every time after
forward()has computed an output.If
with_kwargsisFalseor not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to theforward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called afterforward()is called. The hook should have the following signaturehook(module, args, output) -> None or modified output
如果
with_kwargs为True,则前向钩子将接收传递给 forward 函数的kwargs,并需要返回可能已修改的输出。钩子应该具有以下签名hook(module, args, kwargs, output) -> None or modified output
- 参数
hook (Callable) – 用户定义的待注册钩子。
prepend (bool) – If
True, the providedhookwill be fired before all existingforwardhooks on thistorch.nn.Module. Otherwise, the providedhookwill be fired after all existingforwardhooks on thistorch.nn.Module. Note that globalforwardhooks registered withregister_module_forward_hook()will fire before all hooks registered by this method. Default:Falsewith_kwargs (bool) – 如果为
True,则hook将接收传递给 forward 函数的 kwargs。默认为False。always_call (bool) – 如果为
True,则无论调用 Module 时是否发生异常,都将运行hook。默认为False。
- 返回
一个句柄,可用于通过调用
handle.remove()来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)[source]#
在模块上注册一个前向预钩子。
The hook will be called every time before
forward()is invoked.如果
with_kwargs为 false 或未指定,则输入仅包含传递给模块的位置参数。关键字参数不会传递给钩子,而只会传递给forward。钩子可以修改输入。用户可以返回一个元组或单个修改后的值。我们将把值包装成一个元组,如果返回的是单个值(除非该值本身就是元组)。钩子应该具有以下签名hook(module, args) -> None or modified input
如果
with_kwargs为 true,则前向预钩子将接收传递给 forward 函数的 kwargs。如果钩子修改了输入,则应该返回 args 和 kwargs。钩子应该具有以下签名hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
- 参数
hook (Callable) – 用户定义的待注册钩子。
prepend (bool) – If true, the provided
hookwill be fired before all existingforward_prehooks on thistorch.nn.Module. Otherwise, the providedhookwill be fired after all existingforward_prehooks on thistorch.nn.Module. Note that globalforward_prehooks registered withregister_module_forward_pre_hook()will fire before all hooks registered by this method. Default:Falsewith_kwargs (bool) – 如果为 true,则
hook将接收传递给 forward 函数的 kwargs。默认为False。
- 返回
一个句柄,可用于通过调用
handle.remove()来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_full_backward_hook(hook, prepend=False)[source]#
在模块上注册一个反向传播钩子。
每次计算相对于模块的梯度时,将调用此钩子,其触发规则如下:
通常,钩子在计算相对于模块输入的梯度时触发。
如果模块输入都不需要梯度,则在计算相对于模块输出的梯度时触发钩子。
如果模块输出都不需要梯度,则钩子将不触发。
钩子应具有以下签名
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
The
grad_inputandgrad_outputare tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place ofgrad_inputin subsequent computations.grad_inputwill only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries ingrad_inputandgrad_outputwill beNonefor all non-Tensor arguments.由于技术原因,当此钩子应用于模块时,其前向函数将接收传递给模块的每个张量的视图。类似地,调用者将接收模块前向函数返回的每个张量的视图。
警告
使用反向传播钩子时不允许就地修改输入或输出,否则将引发错误。
- 参数
hook (Callable) – 要注册的用户定义钩子。
prepend (bool) – If true, the provided
hookwill be fired before all existingbackwardhooks on thistorch.nn.Module. Otherwise, the providedhookwill be fired after all existingbackwardhooks on thistorch.nn.Module. Note that globalbackwardhooks registered withregister_module_full_backward_hook()will fire before all hooks registered by this method.
- 返回
一个句柄,可用于通过调用
handle.remove()来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_full_backward_pre_hook(hook, prepend=False)[source]#
在模块上注册一个反向预钩子。
每次计算模块的梯度时,将调用此钩子。钩子应具有以下签名
hook(module, grad_output) -> tuple[Tensor] or None
grad_output是一个元组。钩子不应修改其参数,但可以选择返回一个新的输出梯度,该梯度将取代grad_output用于后续计算。对于所有非 Tensor 参数,grad_output中的条目将为None。由于技术原因,当此钩子应用于模块时,其前向函数将接收传递给模块的每个张量的视图。类似地,调用者将接收模块前向函数返回的每个张量的视图。
警告
使用反向传播钩子时不允许就地修改输入,否则将引发错误。
- 参数
hook (Callable) – 要注册的用户定义钩子。
prepend (bool) – If true, the provided
hookwill be fired before all existingbackward_prehooks on thistorch.nn.Module. Otherwise, the providedhookwill be fired after all existingbackward_prehooks on thistorch.nn.Module. Note that globalbackward_prehooks registered withregister_module_full_backward_pre_hook()will fire before all hooks registered by this method.
- 返回
一个句柄,可用于通过调用
handle.remove()来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_load_state_dict_post_hook(hook)[source]#
注册一个后钩子,用于在模块的
load_state_dict()被调用后运行。- 它应该具有以下签名:
hook(module, incompatible_keys) -> None
The
moduleargument is the current module that this hook is registered on, and theincompatible_keysargument is aNamedTupleconsisting of attributesmissing_keysandunexpected_keys.missing_keysis alistofstrcontaining the missing keys andunexpected_keysis alistofstrcontaining the unexpected keys.如果需要,可以就地修改给定的 incompatible_keys。
Note that the checks performed when calling
load_state_dict()withstrict=Trueare affected by modifications the hook makes tomissing_keysorunexpected_keys, as expected. Additions to either set of keys will result in an error being thrown whenstrict=True, and clearing out both missing and unexpected keys will avoid an error.- 返回
一个句柄,可用于通过调用
handle.remove()来移除添加的钩子- 返回类型
torch.utils.hooks.RemovableHandle
- register_load_state_dict_pre_hook(hook)[source]#
注册一个预钩子,用于在模块的
load_state_dict()被调用之前运行。- 它应该具有以下签名:
hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
- 参数
hook (Callable) – 在加载状态字典之前将调用的可调用钩子。
- register_module(name, module)[source]#
Alias for
add_module().
- register_parameter(name, param)[source]#
向模块添加一个参数。
可以使用给定名称作为属性访问该参数。
- 参数
name (str) – 参数的名称。参数可以通过给定名称从该模块访问。
param (Parameter or None) – parameter to be added to the module. If
None, then operations that run on parameters, such ascuda, are ignored. IfNone, the parameter is not included in the module’sstate_dict.
- register_state_dict_post_hook(hook)[source]#
Register a post-hook for the
state_dict()method.- 它应该具有以下签名:
hook(module, state_dict, prefix, local_metadata) -> None
注册的钩子可以就地修改
state_dict。
- register_state_dict_pre_hook(hook)[source]#
Register a pre-hook for the
state_dict()method.- 它应该具有以下签名:
hook(module, prefix, keep_vars) -> None
注册的钩子可用于在进行
state_dict调用之前执行预处理。
- requires_grad_(requires_grad=True)[source]#
更改自动梯度是否应记录此模块中参数的操作。
此方法就地设置参数的
requires_grad属性。此方法有助于冻结模块的一部分以进行微调或单独训练模型的一部分(例如,GAN 训练)。
请参阅 局部禁用梯度计算,了解 .requires_grad_() 与一些可能与之混淆的类似机制之间的比较。
- save(f, **kwargs)[source]#
Save with a file-like object.
save(f, _extra_files={})
See
torch.jit.savewhich accepts a file-like object. This function, torch.save(), converts the object to a string, treating it as a path. DO NOT confuse these two functions when it comes to the ‘f’ parameter functionality.
- set_extra_state(state)[source]#
设置加载的 state_dict 中包含的额外状态。
This function is called from
load_state_dict()to handle any extra state found within the state_dict. Implement this function and a correspondingget_extra_state()for your module if you need to store extra state within its state_dict.- 参数
state (dict) – 来自 state_dict 的额外状态。
- set_submodule(target, module, strict=False)[source]#
如果存在,设置由
target给定的子模块,否则抛出错误。注意
如果
strict设置为False(默认),该方法将替换现有子模块或在父模块存在的情况下创建新子模块。如果strict设置为True,该方法将仅尝试替换现有子模块,并在子模块不存在时引发错误。例如,假设您有一个
nn.ModuleA,它看起来像这样A( (net_b): Module( (net_c): Module( (conv): Conv2d(3, 3, 3) ) (linear): Linear(3, 3) ) )(图示了一个
nn.ModuleA。A包含一个嵌套子模块net_b,该子模块本身有两个子模块net_c和linear。net_c随后又有一个子模块conv。)要用一个新的
Linear子模块覆盖Conv2d,可以调用set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),其中strict可以是True或False。要将一个新的
Conv2d子模块添加到现有的net_b模块中,可以调用set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))。在上面,如果设置
strict=True并调用set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True),则会引发 AttributeError,因为net_b中不存在名为conv的子模块。- 参数
- 引发
ValueError – 如果
target字符串为空,或者module不是nn.Module的实例。AttributeError – 如果在
target字符串解析出的路径中的任何一点,(子)路径解析为一个不存在的属性名或一个非nn.Module实例的对象。
请参阅
torch.Tensor.share_memory_()。- 返回类型
自我
- state_dict(*args, destination=None, prefix='', keep_vars=False)[source]#
返回一个字典,其中包含对模块整个状态的引用。
参数和持久缓冲区(例如,运行平均值)都包含在内。键是相应的参数和缓冲区名称。设置为
None的参数和缓冲区不包含在内。注意
返回的对象是浅拷贝。它包含对模块参数和缓冲区的引用。
警告
当前
state_dict()还接受destination、prefix和keep_vars的位置参数,顺序为。但是,这正在被弃用,并且在未来的版本中将强制使用关键字参数。警告
请避免使用参数
destination,因为它不是为最终用户设计的。- 参数
- 返回
包含模块整体状态的字典
- 返回类型
示例
>>> module.state_dict().keys() ['bias', 'weight']
- to(*args, **kwargs)[source]#
移动和/或转换参数和缓冲区。
这可以这样调用
- to(device=None, dtype=None, non_blocking=False)[source]
- to(dtype, non_blocking=False)[源码]
- to(tensor, non_blocking=False)[源码]
- to(memory_format=torch.channels_last)[源码]
Its signature is similar to
torch.Tensor.to(), but only accepts floating point or complexdtypes. In addition, this method will only cast the floating point or complex parameters and buffers todtype(if given). The integral parameters and buffers will be moveddevice, if that is given, but with dtypes unchanged. Whennon_blockingis set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.有关示例,请参阅下文。
注意
此方法就地修改模块。
- 参数
device (
torch.device) – 此模块中的参数和缓冲区的目标设备dtype (
torch.dtype) – 此模块中的参数和缓冲区的目标浮点数或复数dtypetensor (torch.Tensor) – 张量,其 dtype 和设备是此模块中所有参数和缓冲区的所需 dtype 和设备
memory_format (
torch.memory_format) – 此模块中 4D 参数和缓冲区的目标内存格式(仅限关键字参数)
- 返回
self
- 返回类型
示例
>>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> gpu1 = torch.device("cuda:1") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
- to_empty(*, device, recurse=True)[source]#
将参数和缓冲区移动到指定设备,而不复制存储。
- 参数
device (
torch.device) – 此模块中的参数和缓冲区的目标设备。recurse (bool) – 是否递归地将子模块的参数和缓冲区移动到指定设备。
- 返回
self
- 返回类型
- train(mode=True)[source]#
将模块设置为训练模式。
This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g.
Dropout,BatchNorm, etc. – 这只对某些模块有影响。有关其在训练/评估模式下的行为的详细信息,例如它们是否受影响,请参阅特定模块的文档,例如Dropout、BatchNorm等。
- xpu(device=None)[source]#
将所有模型参数和缓冲区移动到 XPU。
这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 XPU 上,则应在构建优化器之前调用它。
注意
此方法就地修改模块。
- zero_grad(set_to_none=True)[source]#
重置所有模型参数的梯度。
请参阅
torch.optim.Optimizer下的类似函数以获取更多上下文。- 参数
set_to_none (bool) – 不设置为零,而是将梯度设置为 None。有关详细信息,请参阅
torch.optim.Optimizer.zero_grad()。