评价此页

torch.vmap#

本教程介绍了 torch.vmap,这是一个用于 PyTorch 操作的自动向量化器。torch.vmap 目前是一个原型功能,无法处理许多用例;不过,我们希望收集相关的用例以指导设计。如果您考虑使用 torch.vmap 或认为它在某些方面会非常酷,请通过 pytorch/pytorch#42368 联系我们。

那么,什么是 vmap?#

vmap 是一个高阶函数。它接受一个函数 func,并返回一个新函数,该新函数将 func 映射到输入数据的某个维度上。它在很大程度上受到了 JAX vmap 的启发。

从语义上讲,vmap 将“映射(map)”操作推入到 func 所调用的 PyTorch 操作中,从而有效地使这些操作向量化。

import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap

vmap 的第一个用例是让代码更容易处理批次(batch)维度。用户可以编写一个在单个样本上运行的函数 func,然后通过 vmap(func) 将其提升为一个可以接收样本批次的函数。然而,func 受到许多限制:

  • 它必须是函数式的(不能在其内部修改 Python 数据结构),但原位(in-place)PyTorch 操作除外。

  • 样本批次必须以 Tensor 形式提供。这意味着 vmap 无法开箱即用地处理变长序列。

使用 vmap 的一个例子是计算批次点积。PyTorch 没有提供批次化的 torch.dot API;与其在文档中徒劳地搜寻,不如使用 vmap 来构建一个新函数。

torch.dot                            # [D], [D] -> []
batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)
tensor([ 1.9793, -0.9994])

vmap 有助于隐藏批次维度,从而带来更简单的模型编写体验。

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
def model(feature_vec):
    # Very simple linear model with activation
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)

vmap 还可以帮助向量化那些此前难以或无法批次化的计算。这就引出了我们的第二个用例:批次梯度计算。

PyTorch 的 autograd 引擎计算 vjps(向量-雅可比乘积)。使用 vmap,我们可以计算(批次向量)- 雅可比乘积。

其中一个例子是计算完整的雅可比矩阵(这也可用于计算完整的 Hessian 矩阵)。对于某些函数 f: R^N -> R^N,计算完整的雅可比矩阵通常需要 N 次 autograd.grad 调用,每次调用对应雅可比矩阵的一行。

# Setup
N = 5
def f(x):
    return x ** 2

x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)

# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                 for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)

# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
    return torch.autograd.grad(y, x, v)[0]

jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)

vmap 的第三个主要用例是计算每个样本的梯度(per-sample-gradients)。这是目前 vmap 原型无法高效处理的内容。我们还不确定计算每个样本梯度的 API 应该是怎样的,但如果您有想法,请在 pytorch/pytorch#7786 中发表评论。

def model(sample, weight):
    # do something...
    return torch.dot(sample, weight)

def grad_sample(sample):
    return torch.autograd.functional.vjp(lambda weight: model(sample), weight)[1]

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)

脚本总运行时间: (0 分钟 0.129 秒)