评价此页

torch.nn.utils.parametrize.cached#

torch.nn.utils.parametrize.cached()[source]#

上下文管理器,可在通过 register_parametrization() 注册的参数化对象中启用缓存系统。

当此上下文管理器处于活动状态时,参数化对象的计算值和缓存值将在首次需要时进行计算和缓存。缓存的值在退出上下文管理器时被丢弃。

当在正向传播中使用参数化参数一次以上时,此功能非常有用。例如,当参数化 RNN 的循环核或共享权重时。

激活缓存的最简单方法是在训练和评估中包装神经网络的向前传递。

import torch.nn.utils.parametrize as P

...
with P.cached():
    output = model(inputs)

也可以包装多次使用参数化张量的模块的部分。例如,具有参数化循环核的 RNN 的循环

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)