评价此页

torch.vmap#

本教程介绍 torch.vmap,一个 PyTorch 操作的自动向量化器。torch.vmap 是一个原型功能,无法处理许多用例;但是,我们希望收集其用例以指导设计。如果您考虑使用 torch.vmap 或认为它对某些方面非常有用,请通过 pytorch/pytorch#42368 联系我们。

那么,vmap 是什么?#

vmap 是一个高阶函数。它接受一个函数 func 并返回一个在新函数,该函数将 func 映射到输入的某些维度上。它深受 JAX 的 vmap 的启发。

从语义上讲,vmap 将“映射”推入 func 所调用的 PyTorch 操作中,从而有效地向量化了这些操作。

# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.

vmap 的第一个用例是简化代码中处理批次维度。可以编写一个对单个示例运行的函数 func,然后使用 vmap(func) 将其提升为一个可以接受示例批次的函数。但是,func 受制于许多限制:

  • 它必须是函数式的(不能在其内部修改 Python 数据结构),但允许使用原地 PyTorch 操作。

  • 示例批次必须以 Tensor 的形式提供。这意味着 vmap 不能开箱即用地处理可变长度序列。

使用 vmap 的一个例子是计算批次点积。PyTorch 没有提供批次的 torch.dot API;与其徒劳地翻阅文档,不如使用 vmap 来构建一个新函数。

vmap 有助于隐藏批次维度,从而带来更简单的模型编写体验。

# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).

vmap 还可以帮助向量化以前难以或不可能批处理的计算。这就引出了我们的第二个用例:批次梯度计算。

PyTorch 的自动微分引擎计算 vjps(向量-雅可比乘积)。使用 vmap,我们可以计算(批次向量)-雅可比乘积。

这方面的一个例子是计算完整的雅可比矩阵(也可用于计算完整的Hessian矩阵)。计算某个函数 f: R^N -> R^N 的完整雅可比矩阵通常需要调用 autograd.grad N 次,每次对应一个雅可比行。

# Setup








# Sequential approach




# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.

vmap 的第三个主要用例是计算每个样本的梯度。这是 vmap 原型目前无法高效处理的情况。我们不确定计算每个样本梯度应该使用什么样的 API,但如果您有想法,请在 pytorch/pytorch#7786 中评论。

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)

# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%

脚本总运行时间:(0 分 0.002 秒)