评价此页

torch.vmap#

本教程介绍 torch.vmap,这是一个用于 PyTorch 操作的自动向量化器。torch.vmap 是一个原型功能,无法处理许多用例;但是,我们希望收集其用例以指导设计。如果您正在考虑使用 torch.vmap 或认为它对某些事情非常有用,请通过 pytorch/pytorch#42368 联系我们。

那么,vmap 是什么?#

vmap 是一个高阶函数。它接受一个函数 func 并返回一个新函数,该函数将 func 映射到输入的一些维度上。它深受 JAX 的 vmap 启发。

从语义上讲,vmap 将“映射”推入 func 调用的 PyTorch 操作中,从而有效地向量化这些操作。

# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.

vmap 的第一个用例是使处理代码中的批处理维度变得更容易。可以编写一个在示例上运行的函数 func,然后使用 vmap(func) 将其提升为可以处理批处理示例的函数。但是 func 受到许多限制

  • 它必须是纯函数(不能在其中改变 Python 数据结构),除了就地 PyTorch 操作。

  • 示例批次必须作为张量提供。这意味着 vmap 不能直接处理可变长度序列。

使用 vmap 的一个例子是计算批处理点积。PyTorch 不提供批处理 torch.dot API;与其徒劳地翻阅文档,不如使用 vmap 构建一个新函数

vmap 可以帮助隐藏批处理维度,从而简化模型编写体验。

# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).

vmap 还可以帮助向量化以前难以或不可能批处理的计算。这引出了我们的第二个用例:批处理梯度计算。

PyTorch 自动微分引擎计算 vjps(向量-雅可比积)。使用 vmap,我们可以计算(批处理向量)-雅可比积。

一个例子是计算完整的雅可比矩阵(这也可以应用于计算完整的 Hessian 矩阵)。对于某个函数 f: R^N -> R^N,计算完整的雅可比矩阵通常需要对 autograd.grad 进行 N 次调用,每个雅可比行一次。

# Setup








# Sequential approach




# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.

vmap 的第三个主要用例是计算逐样本梯度。这是 vmap 原型目前无法高性能处理的事情。我们不确定计算逐样本梯度的 API 应该是什么,但如果您有想法,请在 pytorch/pytorch#7786 中发表评论。

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)

# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%

脚本总运行时间:(0 分 0.002 秒)