评价此页

神经切线核#

创建于: 2023年3月15日 | 最后更新: 2025年9月19日 | 最后验证: 未验证

神经切线核 (NTK) 是一个描述 神经网络在训练过程中如何演变 的核。近年来,围绕它进行了大量的研究 。本教程受 JAX 中 NTK 实现的启发 (请参阅 Fast Finite Width Neural Tangent Kernel 以获取详细信息),演示了如何使用 PyTorch 的可组合函数变换 torch.func 轻松计算此量。

注意

本教程需要 PyTorch 2.6.0 或更高版本。

设置#

首先,进行一些设置。让我们定义一个简单的 CNN,我们希望计算其 NTK。

import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev

if torch.accelerator.is_available() and torch.accelerator.device_count() > 0:
    device = torch.accelerator.current_accelerator()
else:
    device = torch.device("cpu")


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, (3, 3))
        self.conv2 = nn.Conv2d(32, 32, (3, 3))
        self.conv3 = nn.Conv2d(32, 32, (3, 3))
        self.fc = nn.Linear(21632, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = x.relu()
        x = self.conv2(x)
        x = x.relu()
        x = self.conv3(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x

让我们生成一些随机数据

x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)

创建模型的函数版本#

torch.func 变换作用于函数。特别是,为了计算 NTK,我们需要一个接受模型参数和单个输入 (而不是输入批次!) 并返回单个输出的函数。

我们将使用 torch.func.functional_call,它允许我们使用不同的参数/缓冲区调用 nn.Module,以帮助完成第一步。

请记住,模型最初是为接受输入数据点批次而编写的。在我们的 CNN 示例中,没有批次间操作。也就是说,批次中的每个数据点都独立于其他数据点。基于此假设,我们可以轻松生成一个在单个数据点上评估模型的函数

net = CNN().to(device)

# Detaching the parameters because we won't be calling Tensor.backward().
params = {k: v.detach() for k, v in net.named_parameters()}

def fnet_single(params, x):
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

计算 NTK:方法 1 (雅可比收缩)#

我们已准备好计算经验 NTK。两个数据点 \(x_1\)\(x_2\) 的经验 NTK 定义为在 \(x_1\) 处评估的模型雅可比与在 \(x_2\) 处评估的模型雅可比之间的矩阵乘积

\[J_{net}(x_1) J_{net}^T(x_2)\]

\(x_1\) 是数据点批次且 \(x_2\) 是数据点批次的批次情况下,我们想要 \(x_1\)\(x_2\) 中所有数据点组合的雅可比之间的矩阵乘积。

第一种方法包括执行此操作——计算两个雅可比,然后收缩它们。以下是在批次情况下计算 NTK 的方法

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = jac2.values()
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)
torch.Size([20, 5, 10, 10])

在某些情况下,您可能只想要该数量的对角线或迹,尤其是在您事先知道网络架构会产生 NTK,其中非对角线元素可以近似为零的情况下。可以轻松地调整上述函数来执行此操作

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = jac2.values()
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    einsum_expr = None
    if compute == 'full':
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute == 'trace':
        einsum_expr = 'Naf,Maf->NM'
    elif compute == 'diagonal':
        einsum_expr = 'Naf,Maf->NMa'
    else:
        assert False

    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
torch.Size([20, 5])

此方法的渐近时间复杂度为 \(N O [FP]\) (计算雅可比的时间) + \(N^2 O^2 P\) (收缩雅可比的时间),其中 \(N\)\(x_1\)\(x_2\) 的批次大小,\(O\) 是模型的输出大小,\(P\) 是参数总数,\([FP]\) 是通过模型进行单次前向传播的成本。有关详细信息,请参阅 Fast Finite Width Neural Tangent Kernel 的第 3.2 节。

计算 NTK:方法 2 (NTK-向量积)#

接下来我们将讨论一种使用 NTK-向量积计算 NTK 的方法。

此方法将 NTK 重构为一系列 NTK-向量积,这些积应用于大小为 \(O\times O\) 的单位矩阵 \(I_O\) 的列 (其中 \(O\) 是模型的输出大小)

\[J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \left[J_{net}(x_1) \left[J_{net}^T(x_2) e_o\right]\right]_{o=1}^{O},\]

其中 \(e_o\in \mathbb{R}^O\) 是单位矩阵 \(I_O\) 的列向量。

  • \(\textrm{vjp}_o = J_{net}^T(x_2) e_o\)。我们可以使用向量-雅可比积来计算它。

  • 现在,考虑 \(J_{net}(x_1) \textrm{vjp}_o\)。这是一个雅可比-向量积!

  • 最后,我们可以使用 vmap 来并行运行上述计算,遍历 \(I_O\) 的所有列 \(e_o\)

这表明我们可以结合使用反向模式 AD (计算向量-雅可比积) 和前向模式 AD (计算雅可比-向量积) 来计算 NTK。

让我们来实现它

def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

        def get_ntk_slice(vec):
            # This computes ``vec @ J(x2).T``
            # `vec` is some unit vector (a single slice of the Identity matrix)
            vjps = vjp_fn(vec)
            # This computes ``J(X1) @ vjps``
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps

        # Here's our identity matrix
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        return vmap(get_ntk_slice)(basis)

    # ``get_ntk(x1, x2)`` computes the NTK for a single data point x1, x2
    # Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,
    # we actually wish to compute the NTK between every pair of data points
    # between {x1} and {x2}. That's what the ``vmaps`` here do.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK', result)

# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
with torch.backends.cudnn.flags(allow_tf32=False):
    result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
    result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)

assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
/usr/local/lib/python3.10/dist-packages/torch/backends/cudnn/__init__.py:145: UserWarning:

Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.ac.cn/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)

我们为 empirical_ntk_ntk_vps 编写的代码看起来直接翻译自上面的数学!这展示了函数变换的强大之处:如果您只使用 torch.autograd.grad,很难编写一个高效的版本。

此方法的渐近时间复杂度为 \(N^2 O [FP]\),其中 \(N\)\(x_1\)\(x_2\) 的批次大小,\(O\) 是模型的输出大小,\([FP]\) 是通过模型进行单次前向传播的成本。因此,此方法比方法 1 (雅可比收缩) 执行更多的网络前向传播 ( \(N^2 O\) 而不是 \(N O\) ),但完全避免了收缩成本 (没有 \(N^2 O^2 P\) 项,其中 \(P\) 是模型参数总数)。因此,当 \(O P\) 相对于 \([FP]\) 很大时,此方法更可取,例如具有许多输出 \(O\) 的全连接 (非卷积) 模型。内存方面,两种方法应该相当。有关详细信息,请参阅 Fast Finite Width Neural Tangent Kernel 的第 3.3 节。

脚本总运行时间: (0 分钟 0.825 秒)