评价此页

torch.nn.factory_kwargs#

torch.nn.factory_kwargs(kwargs)[源码]#

返回规范化的 factory kwargs 字典。

给定 kwargs,返回一个可以直接传递给 factory 函数(如 torch.empty)的规范化 factory kwargs 字典,如果存在不受识别的 kwargs,则会报错。

此函数使编写类似以下代码的代码变得简单

class MyModule(nn.Module):
    def __init__(self, **kwargs):
        factory_kwargs = torch.nn.factory_kwargs(kwargs)
        self.weight = Parameter(torch.empty(10, **factory_kwargs))

为什么你应该使用这个函数而不是直接传递 kwargs

1. 此函数会进行错误验证,因此如果存在意外的 kwargs,我们会立即报告错误,而不是将其推迟到 factory 调用。2. 此函数支持一个特殊的 factory_kwargs 参数,可用于显式指定一个将被用于 factory 函数的 kwarg,以防其中一个 factory kwarg 与签名中已存在的参数冲突(例如,在签名 def f(dtype, **kwargs) 中,你可以通过指定 dtype 来为 factory 函数指定 dtype,这与 dtype 参数不同,即 f(dtype1, factory_kwargs={"dtype": dtype2}))。