注意
跳转至页面底部 下载完整示例代码。
Jacobians、Hessians、hvp、vhp 以及更多:组合函数变换#
创建日期:2023年3月15日 | 最后更新:2023年4月18日 | 最后验证:2024年11月5日
计算 Jacobian 或 Hessian 在许多非传统深度学习模型中非常有用。使用 PyTorch 的常规自动微分 API(Tensor.backward(), torch.autograd.grad)高效地计算这些量很困难(或很麻烦)。PyTorch 受 JAX 启发的 函数变换 API 提供了高效计算各种高阶自动微分量的方法。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
计算 Jacobian#
import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
让我们从一个我们想要计算其 Jacobian 的函数开始。这是一个带有非线性激活的简单线性函数。
让我们添加一些虚拟数据:权重、偏置和一个特征向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
我们将 predict 视为一个将输入 x 从 \(R^D \to R^D\) 映射的函数。PyTorch Autograd 计算向量-雅可比乘积 (vector-Jacobian products)。为了计算该 \(R^D \to R^D\) 函数的完整 Jacobian,我们必须通过每次使用不同的单位向量来逐行计算它。
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,
0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
与其逐行计算 Jacobian,我们可以使用 PyTorch 的 torch.vmap 函数变换来消除 for 循环并向量化计算。我们不能直接将 vmap 应用于 torch.autograd.grad;相反,PyTorch 提供了一个可以与 torch.vmap 组合的 torch.func.vjp 变换。
from torch.func import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在后续教程中,反向模式自动微分 (AD) 与 vmap 的组合将为我们提供逐样本梯度 (per-sample-gradients)。在本教程中,组合反向模式 AD 和 vmap 为我们提供了 Jacobian 计算!vmap 和自动微分变换的各种组合可以得到各种有趣的量。
PyTorch 提供了 torch.func.jacrev 作为便捷函数,它执行 vmap-vjp 组合来计算 Jacobian。jacrev 接受一个 argnums 参数,指定我们希望针对哪个参数计算 Jacobian。
from torch.func import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
让我们比较一下两种计算 Jacobian 方法的性能。函数变换版本快得多(而且输出越多,速度提升越明显)。
通常,我们预期通过 vmap 进行向量化可以帮助消除开销,并更好地利用硬件资源。
vmap 通过将外层循环推入函数的原始操作中来实现这种魔法,从而获得更好的性能。
让我们编写一个简单的函数来评估性能,并处理微秒和毫秒的测量值。
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然后运行性能比较。
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f7a710bc8b0>
compute_jac(xp)
1.48 ms
1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f7ab4bb7c10>
jacrev(predict, argnums=2)(weight, bias, x)
430.48 us
1 measurement, 500 runs , 1 thread
让我们使用我们的 get_perf 函数对上述内容进行相对性能比较。
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
Performance delta: 70.9757 percent improvement with vmap
此外,反转问题并计算模型参数(权重、偏置)而不是输入相对于 Jacobian 是非常容易的。
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式 Jacobian (jacrev) 与前向模式 Jacobian (jacfwd)#
我们提供两种 API 来计算 Jacobian:jacrev 和 jacfwd。
jacrev使用反向模式 AD。如上所示,它是我们vjp和vmap变换的组合。jacfwd使用前向模式 AD。它被实现为我们jvp和vmap变换的组合。
jacfwd 和 jacrev 可以相互替代,但它们具有不同的性能特征。
经验法则:如果你正在计算 \(R^N \to R^M\) 函数的 Jacobian,且输出数量远多于输入数量(例如 \(M > N\)),则首选 jacfwd,否则使用 jacrev。此规则存在例外,但其背后的非严谨论据如下。
在反向模式 AD 中,我们逐行计算 Jacobian;而在前向模式 AD 中(计算 Jacobian-vector products),我们逐列计算它。Jacobian 矩阵有 M 行 N 列,因此如果它是长矩阵或宽矩阵,我们可能更倾向于处理行数或列数较少的方法。
首先,让我们在输入多于输出的情况下进行基准测试。
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f7ab058fe20>
jacfwd(predict, argnums=2)(weight, bias, x)
790.12 us
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f7a710bc4f0>
jacrev(predict, argnums=2)(weight, bias, x)
8.42 ms
1 measurement, 500 runs , 1 thread
然后进行相对基准测试。
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 965.2553 percent improvement with jacrev
现在是反过来的情况——输出 (M) 多于输入 (N)。
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f7aafdcaf20>
jacfwd(predict, argnums=2)(weight, bias, x)
6.89 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f7ab4bb53c0>
jacrev(predict, argnums=2)(weight, bias, x)
510.92 us
1 measurement, 500 runs , 1 thread
以及相对性能比较。
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 1248.5817 percent improvement with jacfwd
使用 functorch.hessian 计算 Hessian#
我们提供了一个计算 Hessian 的便捷 API:torch.func.hessian。Hessian 是 Jacobian 的 Jacobian(或偏导数的偏导数,即二阶导数)。
这表明人们可以简单地组合 functorch Jacobian 变换来计算 Hessian。事实上,在底层,hessian(f) 仅仅是 jacfwd(jacrev(f))。
注意:为了提高性能,根据你的模型,你可能希望使用 jacfwd(jacfwd(f)) 或 jacrev(jacrev(f)) 来利用上述关于宽矩阵与长矩阵的经验法则。
from torch.func import hessian
# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
让我们验证一下,无论使用 Hessian API 还是 jacfwd(jacfwd()),结果是否一致。
True
批处理 Jacobian 和批处理 Hessian#
在上面的例子中,我们一直使用单个特征向量进行操作。在某些情况下,你可能需要计算一批输出相对于一批输入的 Jacobian。也就是说,给定形状为 (B, N) 的输入批次和从 \(R^N \to R^M\) 的函数,我们想要一个形状为 (B, M, N) 的 Jacobian。
实现这一点的最简单方法是使用 vmap。
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])
如果你有一个从 (B, N) -> (B, M) 的函数,并且确定每个输入产生独立的输出,那么有时也可以不使用 vmap,通过对输出求和然后计算该函数的 Jacobian 来实现。
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果你有一个从 \(R^N \to R^M\) 的函数,但输入是批处理的,你可以将 vmap 与 jacrev 组合来计算批处理 Jacobian。
最后,批处理 Hessian 的计算方法类似。最简单的思考方式是使用 vmap 来对 Hessian 计算进行批处理,但在某些情况下,求和技巧也适用。
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
计算 Hessian-vector 乘积#
计算 Hessian-vector 乘积 (hvp) 的原始方法是实例化完整的 Hessian 并执行与向量的点积。我们可以做得更好:事实证明,我们不需要实例化完整的 Hessian 即可做到这一点。我们将介绍计算 hvp 的两种(众多策略中的)策略:- 组合反向模式 AD 与反向模式 AD - 组合反向模式 AD 与前向模式 AD
组合反向模式 AD 与前向模式 AD(相对于反向模式与反向模式)通常是计算 hvp 更内存高效的方式,因为前向模式 AD 不需要构建 Autograd 图并保存用于反向传播的中间变量。
以下是一些示例用法。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch 前向 AD 未覆盖你的操作,那么我们可以改为组合反向模式 AD 与反向模式 AD。
脚本总运行时间: (0 分 10.568 秒)