评价此页

torch.vmap#

torch.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)[源代码]#

vmap 是矢量化映射;vmap(func) 返回一个新函数,该函数将 func 映射到输入的某些维度上。从语义上讲,vmap 将 map 推入 func 调用的 PyTorch 操作中,从而有效地将这些操作向量化。

vmap 对于处理批处理维度很有用:您可以编写一个在单个示例上运行的函数 func,然后使用 vmap(func) 将其提升为一个可以接受示例批次的函数。vmap 还可以与 autograd 组合以计算批处理梯度。

注意

torch.vmap() 为了方便起见,被别名为 torch.func.vmap()。您可以随意使用其中任何一个。

参数
  • func (function) – 接受一个或多个参数的 Python 函数。必须返回一个或多个 Tensor。

  • in_dims (int嵌套结构) – 指定要映射输入的哪个维度。in_dims 的结构应与输入相匹配。如果某个输入的 in_dim 为 None,则表示没有映射维度。默认为 0。

  • out_dims (intTuple[int]) – 指定映射的维度应出现在输出的哪个位置。如果 out_dims 是一个元组,那么它应该为每个输出有一个元素。默认为 0。

  • randomness (str) – 指定此 vmap 中的随机性在批次之间是相同还是不同。如果为 'different',则每个批次的随机性将不同。如果为 'same',则批次之间的随机性将相同。如果为 'error',则任何对随机函数的调用都将出错。默认为 'error'。警告:此标志仅适用于 PyTorch 的随机操作,不适用于 Python 的 random 模块或 numpy 的随机性。

  • chunk_size (Noneint) – 如果为 None(默认),则在输入上应用单个 vmap。如果不是 None,则一次计算 chunk_size 个样本的 vmap。请注意,chunk_size=1 等同于使用 for 循环计算 vmap。如果您在计算 vmap 时遇到内存问题,请尝试使用非 None 的 chunk_size。

返回

返回一个新的“批处理”函数。它接受与 func 相同的输入,只是每个输入的指定 in_dims 索引处会多出一个维度。它返回与 func 相同的输出,只是每个输出的指定 out_dims 索引处会多出一个维度。

返回类型

Callable

使用 vmap() 的一个示例是计算批处理的内积。PyTorch 没有提供批处理的 torch.dot API;与其在文档中徒劳地搜索,不如使用 vmap() 构建一个新函数。

>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)

vmap() 有助于隐藏批处理维度,从而简化模型编写体验。

>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>> # Very simple linear model with activation
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)

vmap() 还可以帮助向量化以前难以或不可能进行批处理的计算。一个例子是高阶梯度计算。PyTorch 的 autograd 引擎计算 vjps(向量-雅可比乘积)。计算某个函数 f: R^N -> R^N 的完整雅可比矩阵通常需要 N 次调用 autograd.grad,每次调用计算雅可比矩阵的一行。使用 vmap(),我们可以向量化整个计算,在一次调用 autograd.grad 中计算雅可比矩阵。

>>> # Setup
>>> N = 5
>>> f = lambda x: x**2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)

vmap() 也可以嵌套,生成具有多个批处理维度的输出

>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.vmap(
...     torch.vmap(torch.dot)
... )  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y)  # tensor of size [2, 3]

如果输入不在第一个维度上进行批处理,in_dims 会指定每个输入的批处理维度为

>>> torch.dot  # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(
...     x, y
... )  # output is [5] instead of [2] if batched along the 0th dimension

如果存在多个输入,且每个输入在不同维度上进行批处理,in_dims 必须是一个元组,其中包含每个输入的批处理维度为

>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(
...     x, y
... )  # second arg doesn't have a batch dim because in_dim[1] was None

如果输入是 Python 结构,in_dims 必须是一个元组,其中包含一个与输入形状匹配的结构。

>>> f = lambda dict: torch.dot(dict["x"], dict["y"])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {"x": x, "y": y}
>>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
>>> batched_dot(input)

默认情况下,输出在第一个维度上进行批处理。但是,可以使用 out_dims 在任何维度上进行批处理。

>>> f = lambda x: x**2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x)  # [5, 2]

对于任何使用 kwargs 的函数,返回的函数不会批处理 kwargs,但会接受 kwargs。

>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>>   return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x)  # scale is not batched, output has shape [2, 2, 5]

注意

vmap 不提供开箱即用的通用自动批处理或处理可变长度序列。