inference_mode#
- class torch.autograd.grad_mode.inference_mode(mode=True)[source]#
启用或禁用推理模式的上下文管理器。
InferenceMode 类似于
no_grad
,应该在您确信操作不会与 autograd 交互时使用(例如,在数据加载或模型评估期间)。与no_grad
相比,它通过禁用视图跟踪和版本计数器递增来移除额外的开销。它也更具限制性,因为在此模式下创建的张量不能用于 autograd 记录的计算。此上下文管理器是线程本地的;它不会影响其他线程中的计算。
也可作为装饰器使用。
注意
推理模式是可局部启用或禁用梯度的几种机制之一。有关比较,请参阅 局部禁用梯度计算。如果难以避免在 autograd 跟踪区域中使用在推理模式下创建的张量,请考虑基准测试您的代码是否使用推理模式,以权衡性能优势与折衷。您始终可以使用
no_grad
代替。注意
与一些局部启用或禁用梯度的其他机制不同,进入 inference_mode 也会禁用 前向模式 AD。
- 参数
mode (bool 或 function) – 要么是用于启用或禁用推理模式的布尔标志,要么是要用推理模式启用的 Python 函数装饰器。
- 示例:
>>> import torch >>> x = torch.ones(1, 2, 3, requires_grad=True) >>> with torch.inference_mode(): ... y = x * x >>> y.requires_grad False >>> y._version Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: Inference tensors do not track version counter. >>> @torch.inference_mode() ... def func(x): ... return x * x >>> out = func(x) >>> out.requires_grad False >>> @torch.inference_mode() ... def doubler(x): ... return x * 2 >>> out = doubler(x) >>> out.requires_grad False