评价此页

ParametrizationList#

class torch.nn.utils.parametrize.ParametrizationList(modules, original, unsafe=False)[源代码]#

一个顺序容器,用于保存和管理参数化 torch.nn.Module 的原始参数或缓冲区。

module[tensor_name] 使用 register_parametrization() 进行参数化时,module.parametrizations[tensor_name] 的类型就是 ParametrizationList

如果第一个注册的参数化具有返回一个张量的 right_inverse 或不具有 right_inverse(在这种情况下,我们假设 right_inverse 是恒等函数),它将以 original 的名称保存该张量。如果它有一个返回多个张量的 right_inverse,这些张量将分别注册为 original0original1,依此类推。

警告

register_parametrization() 会在内部使用此类。此处记录是为了完整性。用户不应实例化此类。

参数
  • modules (sequence) – 代表参数化的模块序列

  • original (ParameterTensor) – 被参数化的参数或缓冲区

  • unsafe (bool) – 一个布尔标志,表示参数化是否可能改变张量的 dtype 和形状。默认为 False。警告:在注册时,参数化未进行一致性检查。启用此标志需自担风险。

right_inverse(value)[源代码]#

按注册的逆序调用参数化列表中的 right_inverse 方法。

然后,如果 right_inverse 输出一个张量,则将其存储在 self.original 中;如果输出多个张量,则分别存储在 self.original0self.original1 等中。

参数

value (Tensor) – 用于初始化模块的值