评价此页

UX 限制#

创建日期:2025 年 6 月 12 日 | 最后更新日期:2025 年 6 月 12 日

torch.func 与 JAX 类似,在可转换的内容方面存在限制。通常来说,JAX 的限制是转换只适用于纯函数,即输出完全由输入决定且不涉及副作用(如变异)的函数。

我们有类似的保证:我们的转换对纯函数效果很好。但是,我们也支持某些原地(in-place)操作。一方面,编写与函数转换兼容的代码可能需要改变你编写 PyTorch 代码的方式;另一方面,你可能会发现我们的转换能让你表达以前在 PyTorch 中难以表达的内容。

通用限制#

所有 torch.func 转换都存在一个共同的限制,即函数不应向全局变量赋值。相反,函数的所有输出都必须从函数返回。这个限制源于 torch.func 的实现方式:每个转换都会将 Tensor 输入封装在特殊的 torch.func Tensor 子类中,以方便转换。

因此,请勿如下操作:

import torch
from torch.func import grad

# Don't do this
intermediate = None

def f(x):
  global intermediate
  intermediate = x.sin()
  z = intermediate.sin()
  return z

x = torch.randn([])
grad_x = grad(f)(x)

请将 f 重写为返回 intermediate

def f(x):
  intermediate = x.sin()
  z = intermediate.sin()
  return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd API#

如果你尝试在由 vmap() 或 torch.func 的 AD 转换(vjp()jvp()jacrev()jacfwd())进行转换的函数内使用 torch.autograd API,如 torch.autograd.gradtorch.autograd.backward,转换可能无法对其进行转换。如果它无法做到,你将收到一条错误消息。

这是 PyTorch AD 支持实现方式上的根本性设计限制,也是我们设计 torch.func 库的原因。请改用 torch.autograd API 的 torch.func 等效项:

  • torch.autograd.gradTensor.backward -> torch.func.vjptorch.func.grad

  • torch.autograd.functional.jvp -> torch.func.jvp

  • torch.autograd.functional.jacobian -> torch.func.jacrevtorch.func.jacfwd

  • torch.autograd.functional.hessian -> torch.func.hessian

vmap 限制#

注意

vmap() 是我们限制最多的转换。与 grad 相关的转换(grad()vjp()jvp())没有这些限制。jacfwd()(以及 hessian(),它使用 jacfwd() 实现)是 vmap()jvp() 的组合,因此它也存在这些限制。

vmap(func) 是一个转换,它返回一个函数,该函数将 func 映射到每个输入 Tensor 的某个新维度上。vmap 的思想是它类似于运行一个 for 循环:对于纯函数(即没有副作用的情况下),vmap(f)(x) 等价于

torch.stack([f(x_i) for x_i in x.unbind(0)])

变异:任意变异 Python 数据结构#

在存在副作用的情况下,vmap() 不再像运行 for 循环那样工作。例如,下面的函数

def f(x, list):
  list.pop()
  print("hello!")
  return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)

将打印“hello!”一次,并只从 lst 中 pop 一个元素。

vmap() 只会执行 f 一次,所以所有的副作用也只发生一次。

这是 vmap 实现方式的结果。torch.func 有一个特殊的内部 BatchedTensor 类。vmap(f)(*inputs) 获取所有 Tensor 输入,将它们转换为 BatchedTensors,然后调用 f(*batched_tensor_inputs)。BatchedTensor 重载了 PyTorch API,为每个 PyTorch 运算符产生批处理(即向量化)的行为。

变异:原地 PyTorch 操作#

你可能在这里遇到关于 vmap 不兼容的原地操作的错误。如果 vmap() 遇到不支持的原地操作,它将引发错误,否则将成功。不支持的操作是指那些会导致写入的 Tensor 元素数量多于被写 Tensor 元素数量的操作。以下是如何发生这种情况的示例:

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(1)
y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)

x 是一个包含一个元素的 Tensor,y 是一个包含三个元素的 Tensor。x + y 包含三个元素(由于广播),但是尝试将三个元素写回 x,而 x 只包含一个元素,这会因为尝试将三个元素写入一个只有一个元素的 Tensor 而引发错误。

如果被写入的 Tensor 在 vmap() 下是批处理的(即它正在被 vmap over),则没有问题。

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y

# Does not raise an error because x is being vmapped over.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

一个常见的解决方法是使用其“new_*”等效项替换工厂函数的调用。例如:

为了说明原因,请看以下内容。

def diag_embed(vec):
  assert vec.dim() == 1
  result = torch.zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
vmap(diag_embed)(vecs)

vmap() 内部,result 是一个形状为 [3, 3] 的 Tensor。但是,尽管 vec 的形状看起来是 [3],但 vec 的实际底层形状是 [2, 3]。将 vec 复制到 result.diagonal()(其形状为 [3])是不可能的,因为 vec 包含的元素太多。

def diag_embed(vec):
  assert vec.dim() == 1
  result = vec.new_zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)

torch.zeros() 替换为 Tensor.new_zeros() 后,result 的底层 Tensor 的形状为 [2, 3, 3],因此现在可以将形状为 [2, 3] 的 vec 复制到 result.diagonal() 中。

Mutation: out= PyTorch Operations#

vmap() 不支持 PyTorch 操作中的 out= 关键字参数。如果在代码中遇到此参数,它将优雅地报错。

这不是根本性的限制;理论上我们将来可以支持此功能,但目前我们选择不这样做。

Data-dependent Python control flow#

我们尚不支持在依赖数据的控制流上使用 vmap。依赖数据的控制流是指 if 语句、while 循环或 for 循环的条件是一个正在被 vmap 的 Tensor。例如,以下代码将引发错误消息:

def relu(x):
  if x > 0:
    return x
  return 0

x = torch.randn(3)
vmap(relu)(x)

但是,任何不依赖于 vmap 化的 Tensor 中值的控制流都可以正常工作。

def custom_dot(x):
  if x.dim() == 1:
    return torch.dot(x, x)
  return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX 支持使用特殊的控制流运算符(例如 jax.lax.condjax.lax.while_loop)在 依赖数据的控制流 上进行转换。我们正在研究为 PyTorch 添加等效功能。

Data-dependent operations (.item())#

我们不(也不会)支持对调用了 Tensor 的 .item() 的用户定义函数使用 vmap。例如,以下代码将引发错误消息:

def f(x):
  return x.item()

x = torch.randn(3)
vmap(f)(x)

请尝试重写您的代码,避免使用 .item() 调用。

您也可能遇到有关使用 .item() 的错误消息,但您可能并未实际使用它。在这种情况下,PyTorch 内部可能正在调用 .item() – 请在 GitHub 上提交一个 issue,我们将修复 PyTorch 内部问题。

Dynamic shape operations (nonzero and friends)#

vmap(f) 要求 f 应用于输入的每个“示例”时返回的 Tensor 形状相同。像 torch.nonzerotorch.is_nonzero 这样的操作不受支持,并将因此报错。

为了理解原因,请考虑以下示例:

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)

torch.nonzero(xs[0]) 返回形状为 2 的 Tensor;而 torch.nonzero(xs[1]) 返回形状为 1 的 Tensor。我们无法构造一个单一的输出 Tensor;输出需要是一个 ragged Tensor(而 PyTorch 尚未实现 ragged Tensor 的概念)。

Randomness#

用户在调用随机操作时可能意图不明。具体来说,一些用户可能希望随机行为在批次之间保持一致,而另一些用户则希望其在批次之间有所不同。为了解决这个问题,vmap 采用了一个随机性标志。

该标志只能传递给 vmap,并且可以取三个值:“error”、“different”或“same”,默认为 error。在“error”模式下,任何对随机函数的调用都会产生一个错误,要求用户根据其用例使用另外两个标志之一。

在“different”随机模式下,批次中的元素会产生不同的随机值。例如:

def add_noise(x):
  y = torch.randn(())  # y will be different across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

在“same”随机模式下,批次中的元素产生相同的随机值。例如:

def add_noise(x):
  y = torch.randn(())  # y will be the same across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

警告

我们的系统只能确定 PyTorch 操作的随机性行为,而无法控制 numpy 等其他库的行为。这与 JAX 解决方案中的限制类似。

注意

使用任一类型的受支持随机性的多个 vmap 调用将不会产生相同的结果。与标准 PyTorch 一样,用户可以通过在 vmap 外部使用 torch.manual_seed() 或使用生成器来获得随机性可复现性。

注意

最后,我们的随机性与 JAX 不同,因为我们不使用无状态 PRNG,部分原因是 PyTorch 对无状态 PRNG 没有完全支持。相反,我们引入了一个标志系统,以支持我们遇到的最常见的随机性形式。如果您的用例不符合这些随机性形式,请提交一个 issue。