torch.func#
创建日期:2025 年 6 月 11 日 | 最后更新日期:2025 年 6 月 11 日
torch.func,之前称为“functorch”,是 JAX 式的可组合函数变换,适用于 PyTorch。
注意
该库目前处于 Beta 阶段。这意味着这些功能通常都能正常工作(除非另有说明),并且我们(PyTorch 团队)致力于推进该库的发展。然而,API 可能会根据用户反馈而更改,并且我们无法完全覆盖所有 PyTorch 操作。
如果您对 API 或您希望覆盖的使用场景有任何建议,请在 GitHub 上提交 issue 或与我们联系。我们很乐意了解您如何使用该库。
什么是可组合函数变换?#
“函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同量的新函数。
torch.func
提供了自动微分变换(grad(f)
返回一个计算f
的梯度的函数)、矢量化/批量化变换(vmap(f)
返回一个在输入批次上计算f
的函数)等。这些函数变换可以任意组合。例如,组合
vmap(grad(f))
可以计算一个称为“每样本梯度”的量,这是标准 PyTorch 目前无法高效计算的。