评价此页

torch.func#

创建日期:2025 年 6 月 11 日 | 最后更新日期:2025 年 6 月 11 日

torch.func,以前称为“functorch”,是 PyTorch 中类似 JAX 的可组合函数变换。

注意

该库目前处于 beta 阶段。这意味着功能通常可以正常工作(除非另有说明),并且我们(PyTorch 团队)致力于推进该库的发展。然而,API 可能会根据用户反馈进行更改,并且我们并未完全覆盖所有 PyTorch 操作。

如果您对 API 或希望涵盖的使用场景有任何建议,请在 GitHub 上提交 issue 或联系我们。我们非常想了解您如何使用该库。

什么是可组合函数变换?#

  • “函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同量的新函数。

  • torch.func 包含自动微分变换(grad(f) 返回一个计算 f 的梯度的函数)、向量化/批处理变换(vmap(f) 返回一个计算输入批次的 f 的函数)等。

  • 这些函数变换可以任意组合。例如,组合 vmap(grad(f)) 可以计算一个称为“每样本梯度”的量,这是标准 PyTorch 目前无法高效计算的。

为什么使用可组合函数变换?#

目前在 PyTorch 中有一些难以实现的使用案例

  • 计算逐样本梯度(或其他逐样本量)

  • 在单机上运行模型集成

  • 高效地批处理 MAML 内部循环中的任务

  • 高效地计算雅可比矩阵和 Hessian 矩阵

  • 高效地计算批处理雅可比矩阵和 Hessian 矩阵

组合 vmap()grad()vjp() 变换,使我们无需为每种变换设计单独的子系统即可表达上述功能。这种可组合函数变换的思想源自 JAX 框架

阅读更多#