torch.func#
创建日期:2025 年 6 月 11 日 | 最后更新日期:2025 年 6 月 11 日
torch.func,以前称为“functorch”,是 PyTorch 中 JAX-like 可组合函数变换。
注意
该库目前处于 beta 阶段。这意味着功能通常可以正常工作(除非另有说明),并且我们(PyTorch 团队)致力于推动该库向前发展。但是,API 可能会根据用户反馈而改变,并且我们没有完全覆盖 PyTorch 操作。
如果您对 API 或需要涵盖的用例有任何建议,请提交 GitHub Issue 或联系我们。我们很乐意听取您如何使用该库的经验。
什么是可组合函数变换?#
“函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同量的新函数。
torch.func
具有自动微分变换(grad(f)
返回一个计算f
梯度的函数)、向量化/批处理变换(vmap(f)
返回一个对输入批次计算f
的函数)等等。这些函数变换可以任意组合。例如,组合
vmap(grad(f))
可以计算一个称为逐样本梯度(per-sample-gradients)的量,而当前的 PyTorch 无法高效计算该量。