从 functorch 迁移到 torch.func#
创建日期:2025 年 6 月 11 日 | 最后更新日期:2025 年 6 月 11 日
torch.func,之前称为“functorch”,是 PyTorch 的 类 JAX 的可组合函数变换。
functorch 最初是 `pytorch/functorch` 仓库中的一个独立库。我们的目标一直是将 functorch 直接合并到 PyTorch 中,并将其作为核心 PyTorch 库提供。
作为合并的最后一步,我们决定从一个顶级包(`functorch`)迁移到 PyTorch 的一部分,以反映函数变换如何直接集成到 PyTorch 核心中。从 PyTorch 2.0 开始,我们弃用 `import functorch`,并要求用户迁移到我们将继续维护的最新 API。`import functorch` 将保留几期以维持向后兼容性。
函数变换#
以下 API 是以下 functorch API 的直接替换。它们完全向后兼容。
functorch API |
PyTorch API(截至 PyTorch 2.0) |
---|---|
functorch.vmap |
|
functorch.grad |
|
functorch.vjp |
|
functorch.jvp |
|
functorch.jacrev |
|
functorch.jacfwd |
|
functorch.hessian |
|
functorch.functionalize |
此外,如果您使用的是 torch.autograd.functional API,请尝试使用 torch.func
的等效 API。在许多情况下,torch.func
的函数变换更具可组合性,性能也更好。
torch.autograd.functional API |
torch.func API(截至 PyTorch 2.0) |
---|---|
NN 模块实用工具#
我们更改了 API,以将函数变换应用于 NN 模块,使其更符合 PyTorch 的设计理念。新 API 不同,因此请仔细阅读本节。
functorch.make_functional#
torch.func.functional_call()
是 functorch.make_functional 和 functorch.make_functional_with_buffers 的替代品。但它不是直接替换。
如果您急需,可以使用 此 gist 中的辅助函数 来模拟 functorch.make_functional 和 functorch.make_functional_with_buffers 的行为。我们建议直接使用 torch.func.functional_call()
,因为它是一个更明确、更灵活的 API。
具体来说,functorch.make_functional 返回一个函数式模块和参数。函数式模块接受参数、模型输入作为参数。torch.func.functional_call()
允许使用新的参数、缓冲区和输入调用现有模块的前向传递。
这里有一个例子,说明如何使用 functorch 与 torch.func
计算模型参数的梯度。
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
def compute_loss(params, inputs, targets):
prediction = fmodel(params, inputs)
return torch.nn.functional.mse_loss(prediction, targets)
grads = functorch.grad(compute_loss)(params, inputs, targets)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)
grads = torch.func.grad(compute_loss)(params, inputs, targets)
这里有一个计算模型参数雅可比矩阵的例子。
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
请注意,为了节约内存,您应该只保留参数的单个副本。model.named_parameters()
不会复制参数。如果在模型训练中原地更新模型的参数,那么您的模型 `nn.Module` 拥有参数的单个副本,一切都正常。
但是,如果您想将参数保存在一个字典中并进行非原地更新,那么就会存在两个参数副本:字典中的一个,以及 `model` 中的一个。在这种情况下,您应该通过将 `model` 转换为元设备(`model.to('meta')`)来使其不持有内存。
functorch.combine_state_for_ensemble#
请使用 torch.func.stack_module_state()
来代替 functorch.combine_state_for_ensemble。 torch.func.stack_module_state()
返回两个字典,一个包含堆叠的参数,另一个包含堆叠的缓冲区,然后这些可以与 torch.vmap()
和 torch.func.functional_call()
一起用于集合。
例如,这是一个关于如何对一个非常简单的模型进行集合的例子。
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy
# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')
params, buffers = torch.func.stack_module_state(models)
# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
return torch.func.functional_call(base_model, (params, buffers), (data,))
output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
functorch.compile#
我们不再支持 functorch.compile(也称为 AOTAutograd)作为 PyTorch 中编译的前端;我们已将 AOTAutograd 集成到 PyTorch 的编译流程中。如果您是用户,请改用 torch.compile()
。