torch.func#
创建日期:2025 年 6 月 11 日 | 最后更新日期:2025 年 6 月 11 日
torch.func(原名 “functorch”)是 PyTorch 中 JAX 风格的可组合函数变换库。
注意
该库目前处于 Beta(测试)阶段。这意味着这些功能通常可以使用(除非另有说明),并且我们(PyTorch 团队)致力于进一步完善该库。不过,API 可能会根据用户反馈进行调整,且我们尚未完全覆盖所有的 PyTorch 操作。
如果您对 API 有任何建议,或希望我们涵盖某些用例,请在 GitHub 上提交 issue 或与我们联系。我们非常期待了解您是如何使用该库的。
什么是可组合函数变换?#
“函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同量的新函数。
torch.func包含自动微分变换(grad(f)返回一个计算f梯度的函数)、向量化/批处理变换(vmap(f)返回一个在输入批次上计算f的函数)以及其他变换。这些函数变换可以任意组合。例如,组合
vmap(grad(f))可以计算一个称为“每样本梯度”的量,这是标准 PyTorch 目前无法高效计算的。
为什么要使用可组合函数变换?#
目前在 PyTorch 中有一些难以实现的使用案例
计算逐样本梯度(或其他逐样本量)
在单机上运行模型集成
高效地批处理 MAML 内部循环中的任务
高效地计算雅可比矩阵和 Hessian 矩阵
高效地计算批处理雅可比矩阵和 Hessian 矩阵
通过组合 vmap()、grad() 和 vjp() 变换,我们无需为每种功能设计单独的子系统即可实现上述需求。这种可组合函数变换的思想源自 JAX 框架。