评价此页

torch.func API 参考#

创建日期:2025 年 6 月 11 日 | 最后更新日期:2025 年 6 月 11 日

函数变换#

vmap

vmap 是向量化映射;vmap(func) 返回一个新的函数,该函数将 func 映射到输入的某个维度上。

grad

grad 算子有助于计算 func 相对于由 argnums 指定的输入(们)的梯度。

grad_and_value

返回一个用于计算梯度和原始值(或前向计算)的元组的函数。

vjp

代表向量-雅可比矩阵乘积,返回一个元组,其中包含 func 应用于 primals 的结果,以及一个函数,该函数在给定 cotangents 时,计算 func 相对于 primals 的反向模式雅可比矩阵乘以 cotangents

jvp

代表雅可比矩阵-向量乘积,返回一个元组,其中包含 func(*primals) 的输出以及“在 primals 处评估的 func 的雅可比矩阵”乘以 tangents

linearize

返回 funcprimals 处的值以及在 primals 处的线性近似。

jacrev

使用反向模式自动微分计算 func 相对于 argnum 索引处的参数(们)的雅可比矩阵。

jacfwd

使用前向模式自动微分计算 func 相对于 argnum 索引处的参数(们)的雅可比矩阵。

hessian

通过前向-反向策略,计算 func 相对于索引为 argnum 的参数(们)的 Hessian。

functionalize

functionalize 是一个变换,可用于从函数中移除(中间)突变和别名,同时保留函数的语义。

用于处理 torch.nn.Modules 的实用程序#

通常,您可以变换调用 torch.nn.Module 的函数。例如,以下是计算一个接收三个值并返回三个值的函数的雅可比矩阵的示例。

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

但是,如果您想执行诸如计算模型参数的雅可比矩阵之类的操作,则需要一种方法来构造一个将参数作为函数输入的函数。这就是 functional_call() 的用途:它接受一个 nn.Module、转换后的 parameters 和 Module 前向传播的输入。它返回使用替换后的参数运行 Module 前向传播的值。

以下是我们如何计算参数上的雅可比矩阵。

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

functional_call

通过替换提供的参数和缓冲区,在模块上执行函数式调用。

stack_module_state

为使用 vmap() 进行集成准备一系列 torch.nn.Modules。

replace_all_batch_norm_modules_

就地更新 root,通过将 running_meanrunning_var 设置为 None,并为 root 中的任何 nn.BatchNorm 模块将 track_running_stats 设置为 False。

如果您正在寻找有关修复 BatchNorm 模块的信息,请遵循此处提供的指导。

调试实用程序#

debug_unwrap

解开一个函子张量(例如。