torch.nn.utils.parametrize.cached#
- torch.nn.utils.parametrize.cached()[源代码]#
当使用
register_parametrization()
注册的参数化时,启用缓存系统的上下文管理器。当此上下文管理器处于活动状态时,参数化对象的值将在首次需要时计算并缓存。缓存的值将在离开上下文管理器时被丢弃。
当在前向传播中使用参数化参数超过一次时,此功能很有用。例如,当参数化 RNN 的循环核或共享权重时。
激活缓存的最简单方法是在神经网络的前向传播中包装。
import torch.nn.utils.parametrize as P ... with P.cached(): output = model(inputs)
在训练和评估中。也可以包装模块中多次使用参数化张量的部分。例如,具有参数化循环核的 RNN 的循环。
with P.cached(): for x in xs: out_rnn = self.rnn_cell(x, out_rnn)