torch.nn.factory_kwargs#
- torch.nn.factory_kwargs(kwargs)[源码]#
返回规范化的 factory kwargs 字典。
给定 kwargs,返回一个可以直接传递给 factory 函数(如 torch.empty)的规范化 factory kwargs 字典,如果存在不受识别的 kwargs,则会报错。
此函数使编写类似以下代码的代码变得简单
class MyModule(nn.Module): def __init__(self, **kwargs): factory_kwargs = torch.nn.factory_kwargs(kwargs) self.weight = Parameter(torch.empty(10, **factory_kwargs))
为什么你应该使用这个函数而不是直接传递 kwargs?
1. 此函数会进行错误验证,因此如果存在意外的 kwargs,我们会立即报告错误,而不是将其推迟到 factory 调用。2. 此函数支持一个特殊的 factory_kwargs 参数,可用于显式指定一个将被用于 factory 函数的 kwarg,以防其中一个 factory kwarg 与签名中已存在的参数冲突(例如,在签名
def f(dtype, **kwargs)
中,你可以通过指定dtype
来为 factory 函数指定dtype
,这与 dtype 参数不同,即f(dtype1, factory_kwargs={"dtype": dtype2})
)。