注意
跳转至页面底部 下载完整示例代码。
逐样本梯度#
创建日期: 2023年3月15日 | 最后更新: 2025年7月30日 | 最后验证: 2024年11月5日
什么是逐样本梯度?#
逐样本梯度计算是指计算数据批次中每一个样本的梯度。这是差分隐私、元学习和优化研究中的一个有用量。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
# Here's a simple CNN and loss function:
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def loss_fn(predictions, targets):
return F.nll_loss(predictions, targets)
让我们生成一批模拟数据,并假设我们正在使用 MNIST 数据集。模拟图像大小为 28x28,我们使用大小为 64 的小批量(minibatch)。
device = 'cuda'
num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)
targets = torch.randint(10, (64,), device=device)
在常规模型训练中,我们会将小批量数据输入模型,然后调用 .backward() 来计算梯度。这将生成整个小批量的“平均”梯度。
model = SimpleCNN().to(device=device)
predictions = model(data) # move the entire mini-batch through the model
loss = loss_fn(predictions, targets)
loss.backward() # back propagate the 'average' gradient of this mini-batch
与上述方法相比,逐样本梯度计算等同于:
对于数据中的每个个体样本,执行一次前向传播和反向传播,以获得单独的(逐样本)梯度。
def compute_grad(sample, target):
sample = sample.unsqueeze(0) # prepend batch dimension for processing
target = target.unsqueeze(0)
prediction = model(sample)
loss = loss_fn(prediction, target)
return torch.autograd.grad(loss, list(model.parameters()))
def compute_sample_grads(data, targets):
""" manually process each sample with per sample gradient """
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
sample_grads = zip(*sample_grads)
sample_grads = [torch.stack(shards) for shards in sample_grads]
return sample_grads
per_sample_grads = compute_sample_grads(data, targets)
sample_grads[0] 是 model.conv1.weight 的逐样本梯度。model.conv1.weight.shape 为 [32, 1, 3, 3];请注意,批次中的每个样本都有一个对应的梯度,总计 64 个。
print(per_sample_grads[0].shape)
torch.Size([64, 32, 1, 3, 3])
逐样本梯度的高效实现方法:使用函数变换#
我们可以通过使用函数变换来高效地计算逐样本梯度。
torch.func 函数变换 API 提供了对函数的变换能力。我们的策略是定义一个计算损失的函数,然后应用变换来构建一个计算逐样本梯度的函数。
我们将使用 torch.func.functional_call 函数将 nn.Module 视为一个函数来处理。
首先,让我们将 model 中的状态提取到两个字典中:参数(parameters)和缓冲区(buffers)。我们将它们 detach,因为我们不会使用常规的 PyTorch 自动求导(例如 Tensor.backward() 或 torch.autograd.grad)。
from torch.func import functional_call, vmap, grad
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}
接下来,让我们定义一个函数,给定单个输入(而非一批输入)来计算模型损失。重要的是该函数必须接受参数、输入和目标,因为我们将对它们进行变换。
注意:由于模型最初是为了处理批次而编写的,我们将使用 torch.unsqueeze 来增加一个批次维度。
def compute_loss(params, buffers, sample, target):
batch = sample.unsqueeze(0)
targets = target.unsqueeze(0)
predictions = functional_call(model, (params, buffers), (batch,))
loss = loss_fn(predictions, targets)
return loss
现在,让我们使用 grad 变换来创建一个新函数,用于计算相对于 compute_loss 第一个参数(即 params)的梯度。
ft_compute_grad = grad(compute_loss)
ft_compute_grad 函数计算单个(样本,目标)对的梯度。我们可以使用 vmap 来让它计算整个样本和目标批次的梯度。注意,in_dims=(None, None, 0, 0) 是因为我们希望在数据和目标的第 0 维上映射 ft_compute_grad,并对每个样本使用相同的 params 和缓冲区。
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
最后,让我们使用变换后的函数来计算逐样本梯度。
我们可以进行二次检查,确认使用 grad 和 vmap 得到的结果与手动逐个处理的结果一致。
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1.2e-1, rtol=1e-5)
简单说明:vmap 可以转换的函数类型存在一定限制。最适合转换的函数是纯函数:即输出仅由输入决定且没有副作用(例如变量修改)的函数。vmap 无法处理任意 Python 数据结构的修改,但可以处理许多 PyTorch 的原地(in-place)操作。
性能比较#
好奇 vmap 的性能表现如何吗?
目前,在较新的 GPU(如 A100 Ampere 架构)上可获得最佳结果,在此示例中我们观察到高达 25 倍的加速。以下是我们构建机器上的一些测试结果。
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
second_res = second.times[0]
first_res = first.times[0]
gain = (first_res-second_res)/first_res
if gain < 0: gain *=-1
final_gain = gain*100
print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)
print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')
get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")
Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fbf15e17310>
compute_sample_grads(data, targets)
63.69 ms
1 measurement, 100 runs , 1 thread
Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fbf16773250>
ft_compute_sample_grad(params, buffers, data, targets)
3.40 ms
1 measurement, 100 runs , 1 thread
Performance delta: 1775.3596 percent improvement with vmap
在 PyTorch 中还有其他优化方案(如 pytorch/opacus)可以计算逐样本梯度,它们的性能也优于朴素方法。但令人兴奋的是,结合 vmap 和 grad 可以带来显著的加速效果。
总的来说,使用 vmap 进行向量化通常比在 for 循环中运行函数更快,且可以媲美手动批处理。但也存在一些例外,例如如果我们尚未针对某个特定操作实现 vmap 规则,或者底层内核未针对旧硬件(GPU)进行优化。如果您遇到此类情况,请通过 GitHub issue 向我们反馈。
脚本运行总耗时: (0 分钟 7.736 秒)