ParameterList#
- class torch.nn.ParameterList(values=None)[source]#
将参数保存在一个列表中。
ParameterList的用法类似于常规的 Python 列表,但其中的Parameter类型 Tensors 会被正确地注册,并且所有Module方法都能访问到它们。请注意,构造函数、列表元素的赋值、
append()方法和extend()方法都会将任何Tensor转换为Parameter。- 参数
parameters (iterable, optional) – 要添加到列表中的元素的迭代器。
示例
class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.params = nn.ParameterList( [nn.Parameter(torch.randn(10, 10)) for i in range(10)] ) def forward(self, x): # ParameterList can act as an iterable, or be indexed using ints for i, p in enumerate(self.params): x = self.params[i // 2].mm(x) + p.mm(x) return x