注意
跳转至页面底部下载完整示例代码。
神经正切核#
创建日期:2023年3月15日 | 最后更新:2025年9月19日 | 最后验证:未验证
神经正切核 (NTK) 是描述 神经网络在训练过程中如何演化 的一种核函数。近年来,围绕该课题已有大量研究 发表。本教程受 JAX 版 NTK 实现(详见 快速有限宽度神经正切核)的启发,演示了如何利用 PyTorch 的可组合函数变换工具 torch.func 轻松计算这一量。
注意
本教程需要 PyTorch 2.6.0 或更高版本。
设置#
首先进行一些准备工作。让我们定义一个简单的 CNN 模型,并计算其 NTK。
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
if torch.accelerator.is_available() and torch.accelerator.device_count() > 0:
device = torch.accelerator.current_accelerator()
else:
device = torch.device("cpu")
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 32, (3, 3))
self.conv2 = nn.Conv2d(32, 32, (3, 3))
self.conv3 = nn.Conv2d(32, 32, (3, 3))
self.fc = nn.Linear(21632, 10)
def forward(self, x):
x = self.conv1(x)
x = x.relu()
x = self.conv2(x)
x = x.relu()
x = self.conv3(x)
x = x.flatten(1)
x = self.fc(x)
return x
生成一些随机数据。
x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)
创建模型的功能版本#
torch.func 变换作用于函数。特别是,为了计算 NTK,我们需要一个能够接收模型参数和单个输入(而不是一批输入!)并返回单个输出的函数。
我们将使用 torch.func.functional_call,它允许我们使用不同的参数/缓冲区来调用 nn.Module,从而协助完成第一步。
请记住,模型最初编写时是为了接收一批输入数据点。在我们的 CNN 示例中,不存在批次间的操作。也就是说,批次中的每个数据点都独立于其他数据点。基于此假设,我们可以轻松生成一个在单个数据点上评估模型的函数。
net = CNN().to(device)
# Detaching the parameters because we won't be calling Tensor.backward().
params = {k: v.detach() for k, v in net.named_parameters()}
def fnet_single(params, x):
return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)
计算 NTK:方法 1(雅可比矩阵收缩)#
我们准备好计算经验 NTK 了。对于两个数据点 \(x_1\) 和 \(x_2\),经验 NTK 定义为模型在 \(x_1\) 处评估的雅可比矩阵与模型在 \(x_2\) 处评估的雅可比矩阵之间的矩阵积:
在批处理情况下,如果 \(x_1\) 是一批数据点,\(x_2\) 也是一批数据点,那么我们需要计算 \(x_1\) 和 \(x_2\) 中所有数据点组合的雅可比矩阵之间的矩阵积。
第一种方法就是如此——计算两个雅可比矩阵并进行收缩。以下是如何在批处理情况下计算 NTK 的方法:
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = jac1.values()
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = jac2.values()
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)
torch.Size([20, 5, 10, 10])
在某些情况下,你可能只需要该量的对角线或迹,特别是如果你预先知道网络架构导致 NTK 的非对角元素可以近似为零。调整上述函数来实现这一点非常容易。
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = jac1.values()
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = jac2.values()
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
einsum_expr = None
if compute == 'full':
einsum_expr = 'Naf,Mbf->NMab'
elif compute == 'trace':
einsum_expr = 'Naf,Maf->NM'
elif compute == 'diagonal':
einsum_expr = 'Naf,Maf->NMa'
else:
assert False
result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
torch.Size([20, 5])
此方法的渐近时间复杂度为 \(N O [FP]\)(计算雅可比矩阵的时间)+ \(N^2 O^2 P\)(收缩雅可比矩阵的时间),其中 \(N\) 是 \(x_1\) 和 \(x_2\) 的批大小,\(O\) 是模型的输出大小,\(P\) 是参数总数,\([FP]\) 是模型单次前向传播的开销。详情请参阅 快速有限宽度神经正切核 中的第 3.2 节。
计算 NTK:方法 2(NTK-向量积)#
接下来我们要讨论的是使用 NTK-向量积计算 NTK 的方法。
此方法将 NTK 重新表述为应用于大小为 \(O\times O\) 的单位矩阵 \(I_O\)(其中 \(O\) 是模型输出大小)各列的 NTK-向量积堆栈:
其中 \(e_o\in \mathbb{R}^O\) 是单位矩阵 \(I_O\) 的列向量。
令 \(\textrm{vjp}_o = J_{net}^T(x_2) e_o\)。我们可以利用向量-雅可比积 (VJP) 来计算它。
现在,考虑 \(J_{net}(x_1) \textrm{vjp}_o\)。这是一个雅可比-向量积 (JVP)!
最后,我们可以使用
vmap并行处理 \(I_O\) 的所有列 \(e_o\),运行上述计算。
这表明我们可以结合反向模式自动微分(计算 VJP)和前向模式自动微分(计算 JVP)来计算 NTK。
让我们编写代码实现它。
def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
def get_ntk(x1, x2):
def func_x1(params):
return func(params, x1)
def func_x2(params):
return func(params, x2)
output, vjp_fn = vjp(func_x1, params)
def get_ntk_slice(vec):
# This computes ``vec @ J(x2).T``
# `vec` is some unit vector (a single slice of the Identity matrix)
vjps = vjp_fn(vec)
# This computes ``J(X1) @ vjps``
_, jvps = jvp(func_x2, (params,), vjps)
return jvps
# Here's our identity matrix
basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
return vmap(get_ntk_slice)(basis)
# ``get_ntk(x1, x2)`` computes the NTK for a single data point x1, x2
# Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,
# we actually wish to compute the NTK between every pair of data points
# between {x1} and {x2}. That's what the ``vmaps`` here do.
result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)
if compute == 'full':
return result
if compute == 'trace':
return torch.einsum('NMKK->NM', result)
if compute == 'diagonal':
return torch.einsum('NMKK->NMK', result)
# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
with torch.backends.cudnn.flags(allow_tf32=False):
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
empirical_ntk_ntk_vps 的代码看起来就像是上述数学公式的直接翻译!这展示了函数变换的强大之处:如果仅使用 torch.autograd.grad,想要写出高效的版本会非常困难。
此方法的渐近时间复杂度为 \(N^2 O [FP]\),其中 \(N\) 为 \(x_1\) 和 \(x_2\) 的批大小,\(O\) 为模型输出大小,\([FP]\) 为单次前向传播开销。因此,该方法执行的前向传播次数比方法 1(雅可比矩阵收缩,\(N^2 O\) vs \(N O\))更多,但完全避免了收缩开销(没有 \(N^2 O^2 P\) 项,其中 \(P\) 是总参数量)。因此,当 \(O P\) 相对于 \([FP]\) 较大时(例如具有大量输出 \(O\) 的全连接模型),此方法更可取。在内存消耗方面,两种方法相当。详情请参阅 快速有限宽度神经正切核 中的第 3.3 节。
脚本运行总时间:(0 分 0.733 秒)