注意
转到末尾 下载完整的示例代码。
torch.vmap#
本教程介绍 torch.vmap,这是一个用于 PyTorch 操作的自动向量化器。torch.vmap 是一个原型功能,无法处理许多用例;但是,我们希望收集其用例以指导设计。如果您正在考虑使用 torch.vmap 或认为它对某些事情非常有用,请通过 pytorch/pytorch#42368 联系我们。
那么,vmap 是什么?#
vmap 是一个高阶函数。它接受一个函数 func 并返回一个新函数,该函数将 func 映射到输入的一些维度上。它深受 JAX 的 vmap 启发。
从语义上讲,vmap 将“映射”推入 func 调用的 PyTorch 操作中,从而有效地向量化这些操作。
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
vmap 的第一个用例是使处理代码中的批处理维度变得更容易。可以编写一个在示例上运行的函数 func,然后使用 vmap(func) 将其提升为可以处理批处理示例的函数。但是 func 受到许多限制
它必须是纯函数(不能在其中改变 Python 数据结构),除了就地 PyTorch 操作。
示例批次必须作为张量提供。这意味着 vmap 不能直接处理可变长度序列。
使用 vmap 的一个例子是计算批处理点积。PyTorch 不提供批处理 torch.dot API;与其徒劳地翻阅文档,不如使用 vmap 构建一个新函数
vmap 可以帮助隐藏批处理维度,从而简化模型编写体验。
# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
vmap 还可以帮助向量化以前难以或不可能批处理的计算。这引出了我们的第二个用例:批处理梯度计算。
PyTorch 自动微分引擎计算 vjps(向量-雅可比积)。使用 vmap,我们可以计算(批处理向量)-雅可比积。
一个例子是计算完整的雅可比矩阵(这也可以应用于计算完整的 Hessian 矩阵)。对于某个函数 f: R^N -> R^N,计算完整的雅可比矩阵通常需要对 autograd.grad 进行 N 次调用,每个雅可比行一次。
# Setup
# Sequential approach
# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
vmap 的第三个主要用例是计算逐样本梯度。这是 vmap 原型目前无法高性能处理的事情。我们不确定计算逐样本梯度的 API 应该是什么,但如果您有想法,请在 pytorch/pytorch#7786 中发表评论。
# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.
# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%
脚本总运行时间:(0 分 0.002 秒)