recipes/torch_compile_torch_function_modes
在 Google Colab 中运行
Colab
下载 Notebook
Notebook
在 GitHub 上查看
GitHub
注意
转到末尾 下载完整的示例代码。
(beta) 在 torch.compile 中使用 Torch 函数模式#
作者: Michael Lazos
- 本秘籍介绍如何在跟踪时将 PyTorch 的关键可扩展点——
Torch 函数模式与
torch.compile
结合使用,以覆盖 Torch 算子(也称为 ops)的行为,且没有运行时开销。
注意
本秘籍需要 PyTorch 2.7.0 或更高版本。
重写 Torch 算子 (torch.add -> torch.mul)#
在本示例中,我们将使用 Torch 函数模式将加法操作替换为乘法操作。如果某个后端具有应针对给定操作调度的自定义实现,则此类覆盖很常见。
import torch
# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
print("Exiting because torch.compile is not supported on this device.")
import sys
sys.exit(0)
from torch.overrides import BaseTorchFunctionMode
# Define our mode, Note: ``BaseTorchFunctionMode``
# implements the actual invocation of func(..)
class AddToMultiplyMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if func == torch.Tensor.add:
func = torch.mul
return super().__torch_function__(func, types, args, kwargs)
@torch.compile()
def test_fn(x, y):
return x + y * x # Note: infix operators map to torch.Tensor.* methods
x = torch.rand(2, 2)
y = torch.rand_like(x)
with AddToMultiplyMode():
z = test_fn(x, y)
assert torch.allclose(z, x * y * x)
# The mode can also be used within the compiled region as well like this:
@torch.compile()
def test_fn(x, y):
with AddToMultiplyMode():
return x + y * x # Note: infix operators map to torch.Tensor.* methods
x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)
assert torch.allclose(z, x * y * x)
结论#
在本秘籍中,我们演示了如何使用 torch.compile
中的 Torch 函数模式来覆盖 torch.*
算子的行为。这使得用户可以利用 Torch 函数模式的可扩展性优势,而无需在每次算子调用时都产生 Torch 函数的运行时开销。
有关其他示例和 Torch 函数模式的背景知识,请参阅 使用模式扩展 Torch API。
脚本总运行时间: (0 分钟 9.875 秒)