torch.nn.utils.skip_init#
- torch.nn.utils.skip_init(module_cls, *args, **kwargs)[源代码]#
给定一个模块类对象和参数/关键字参数,在不初始化参数/缓冲区的情况下实例化模块。
这在使用默认初始化不必要时,对于初始化过程缓慢或将执行自定义初始化的情况非常有用。由于此函数的实现方式,存在一些注意事项:
1. 模块在其构造函数中必须接受一个 device 参数,该参数将传递给构造期间创建的任何参数或缓冲区。
2. 模块在其构造函数中不得对参数执行除初始化(即来自
torch.nn.init
的函数)之外的任何计算。如果满足这些条件,则模块可以实例化,其参数/缓冲区值未初始化,就像使用
torch.empty()
创建的一样。- 参数
module_cls – 类对象;应为
torch.nn.Module
的子类args – 要传递给模块构造函数的参数
kwargs – 要传递给模块构造函数的关键字参数
- 返回
实例化了具有未初始化参数/缓冲区的模块
示例
>>> import torch >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) >>> m.weight Parameter containing: tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]], requires_grad=True) >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1) >>> m2.weight Parameter containing: tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, 4.5915e-41]], requires_grad=True)