评价此页

torch.library#

创建于:2022年6月13日 | 最后更新于:2025年8月13日

torch.library 是用于扩展 PyTorch 核心算子库的 API 集合。它包含用于测试自定义算子、创建新的自定义算子以及扩展使用 PyTorch C++ 算子注册 API(例如 aten 算子)定义的算子的实用工具。

有关有效使用这些 API 的详细指南,请参阅 PyTorch 自定义算子着陆页,了解如何有效使用这些 API 的更多详细信息。

测试自定义算子#

使用 torch.library.opcheck() 测试自定义算子是否存在 Python torch.library 和/或 C++ TORCH_LIBRARY API 的不正确用法。此外,如果您的算子支持训练,请使用 torch.autograd.gradcheck() 来测试梯度是否在数学上是正确的。

torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True, atol=None, rtol=None)[source]#

给定一个算子和一些示例参数,测试该算子是否已正确注册。

也就是说,当您使用 torch.library/TORCH_LIBRARY API 创建自定义算子时,您为其指定了元数据(例如可变性信息),这些 API 要求您传递的函数满足某些属性(例如,在 fake/meta/abstract 内核中不允许访问数据指针)。opcheck 测试这些元数据和属性。

具体来说,我们测试以下内容:

  • test_schema: 模式是否与算子的实现匹配。例如:如果模式指定了一个 Tensor 被修改,那么我们检查实现是否修改了 Tensor。如果模式指定我们返回一个新的 Tensor,那么我们检查实现是否返回了一个新的 Tensor(而不是现有 Tensor 的视图)。

  • test_autograd_registration: 如果算子支持训练(autograd):我们检查其 autograd 公式是否通过 torch.library.register_autograd 或手动注册到多个 DispatchKey::Autograd 键之一进行注册。任何其他基于 DispatchKey 的注册都可能导致未定义行为。

  • test_faketensor: 如果算子具有 FakeTensor 内核(并且是正确的)。FakeTensor 内核对于算子与 PyTorch 编译 API(torch.compile/export/FX)一起使用是必需的(但不是充分条件)。我们检查算子是否注册了 FakeTensor 内核(也称为 meta 内核),并且其是否正确。此测试获取在真实张量上运行算子的结果和在 FakeTensor 上运行算子的结果,并检查它们是否具有相同的 Tensor 元数据(大小/步幅/dtype/设备等)。

  • test_aot_dispatch_dynamic: 如果算子与 PyTorch 编译 API(torch.compile/export/FX)的行为正确。这会检查在 eager 模式 PyTorch 和 torch.compile 下的输出(以及梯度,如果适用)是否相同。此测试是 test_faketensor 的超集,是一个端到端测试;它还测试算子是否支持函数化,以及反向传播(如果存在)是否也支持 FakeTensor 和函数化。

为了获得最佳结果,请多次使用一组代表性的输入调用 opcheck。如果您的算子支持 autograd,请使用 opcheck 并将 requires_grad = True 作为输入;如果您的算子支持多种设备(例如 CPU 和 CUDA),请使用 opcheck 并使用所有受支持设备上的输入。

参数:
  • op (OpOverload | OpOverloadPacket | CustomOpDef) – 算子。必须是用 torch.library.custom_op() 装饰的函数,或者在 torch.ops.* 中找到的 OpOverload/OpOverloadPacket(例如 torch.ops.aten.sin, torch.ops.mylib.foo)。

  • args (tuple[Any, ...]) – 算子的参数。

  • kwargs (dict[str, Any] | None) – 算子的关键字参数。

  • test_utils (str | Sequence[str]) – 需要运行的测试。默认值:所有测试。示例:“test_schema”, “test_faketensor”。

  • raise_exception (bool) – 是否在第一次出错时引发异常。如果为 False,我们将返回一个字典,其中包含有关每个测试是否通过的信息。

  • rtol (Optional[float]) – 浮点数比较的相对容差。如果指定,则必须同时指定 atol。如果省略,则根据 dtype 选择默认值(请参阅 torch.testing.assert_close() 中的表格)。

  • atol (Optional[float]) – 浮点数比较的绝对容差。如果指定,则必须同时指定 rtol。如果省略,则根据 dtype 选择默认值(请参阅 torch.testing.assert_close() 中的表格)。

返回类型:

dict[str, str]

警告

opcheck 和 torch.autograd.gradcheck() 测试的内容不同;opcheck 测试您对 torch.library API 的使用是否正确,而 torch.autograd.gradcheck() 测试您的 autograd 公式在数学上是否正确。使用两者来测试支持梯度计算的自定义算子。

示例

>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: float) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np * y
>>>     return torch.from_numpy(z_np).to(x.device)
>>>
>>> @numpy_mul.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     y, = inputs
>>>     ctx.y = y
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.y, None
>>>
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
>>>
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14),
>>>     (torch.randn(2, 3, device='cuda'), 2.718),
>>>     (torch.randn(1, 10, requires_grad=True), 1.234),
>>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
>>> ]
>>>
>>> for args in sample_inputs:
>>>     torch.library.opcheck(numpy_mul, args)

在 Python 中创建新的自定义算子#

使用 torch.library.custom_op() 创建新的自定义算子。

torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None, tags=None)[source]#

将函数包装成自定义算子。

您可能希望创建自定义算子的原因包括:- 包装第三方库或自定义内核以与 Autograd 等 PyTorch 子系统一起使用。- 防止 torch.compile/export/FX 跟踪窥探您的函数。

此 API 用作函数上的装饰器(请参阅示例)。提供的函数必须具有类型提示;这些类型提示对于与 PyTorch 的各种子系统进行接口至关重要。

参数:
  • name (str) – 自定义算子的名称,格式为“{namespace}::{name}”,例如“mylib::my_linear”。该名称用作算子在 PyTorch 子系统(例如 torch.export、FX 图)中的稳定标识符。为避免名称冲突,请使用您的项目名称作为命名空间;例如,pytorch/fbgemm 中的所有自定义算子都使用“fbgemm”作为命名空间。

  • mutates_args (Iterable[str] or "unknown") – 函数修改的参数名称。这必须是准确的,否则行为未定义。如果为“unknown”,则悲观地假设算子的所有输入都被修改。

  • device_types (None | str | Sequence[str]) – 函数有效的设备类型。如果未提供设备类型,则该函数用作所有设备类型的默认实现。示例:“cpu”、“cuda”。当为接受无 Tensor 的算子注册特定于设备的实现时,我们要求该算子具有“device: torch.device 参数”。

  • schema (None | str) – 算子的模式字符串。如果为 None(推荐),我们将从其类型注解中推断算子的模式。我们建议让 PyTorch 推断模式,除非您有特殊原因不这样做。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。

返回类型:

Callable[[Callable[[…], object]], CustomOpDef] | CustomOpDef

注意

我们建议不要传递 schema 参数,而是让 PyTorch 从类型注解中推断它。手动编写模式容易出错。当 PyTorch 对类型注解的解释不符合您的预期时,您可能希望提供自己的模式。有关如何编写模式字符串的更多信息,请参阅 此处

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)
>>>
>>> # Example of a factory function
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
>>> def bar(device: torch.device) -> Tensor:
>>>     return torch.ones(3)
>>>
>>> bar("cpu")
torch.library.triton_op(name, fn=None, /, *, mutates_args, schema=None)[source]#

创建一个自定义算子,其实现由 1 个或多个 triton 内核支持。

这是使用 triton 内核与 PyTorch 的一种更结构化的方式。优先使用没有 torch.library 自定义算子包装器(如 torch.library.custom_op()torch.library.triton_op())的 triton 内核,因为这样更简单;仅当您想创建一个行为类似于 PyTorch 内置算子的算子时,才使用 torch.library.custom_op()/torch.library.triton_op()。例如,您可以使用 torch.library 包装器 API 来定义 triton 内核在传递 Tensor 子类或 TorchDispatchMode 时 的行为。

使用 torch.library.triton_op() 而不是 torch.library.custom_op(),当实现由 1 个或多个 triton 内核组成时。 torch.library.custom_op() 将自定义算子视为不透明(torch.compile()torch.export.export() 永远不会跟踪它们),但 triton_op 使实现对这些子系统可见,从而允许它们优化 triton 内核。

请注意,fn 只能由 PyTorch 可理解的算子和 triton 内核调用组成。在 fn 中调用的任何 triton 内核都必须包装在 torch.library.wrap_triton() 的调用中。

参数:
  • name (str) – 自定义算子的名称,格式为“{namespace}::{name}”,例如“mylib::my_linear”。该名称用作算子在 PyTorch 子系统(例如 torch.export、FX 图)中的稳定标识符。为避免名称冲突,请使用您的项目名称作为命名空间;例如,pytorch/fbgemm 中的所有自定义算子都使用“fbgemm”作为命名空间。

  • mutates_args (Iterable[str] or "unknown") – 函数修改的参数名称。这必须是准确的,否则行为未定义。如果为“unknown”,则悲观地假设算子的所有输入都被修改。

  • schema (None | str) – 算子的模式字符串。如果为 None(推荐),我们将从其类型注解中推断算子的模式。我们建议让 PyTorch 推断模式,除非您有特殊原因不这样做。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。

返回类型:

Callable

示例

>>> import torch
>>> from torch.library import triton_op, wrap_triton
>>>
>>> import triton
>>> from triton import language as tl
>>>
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0,
>>>     in_ptr1,
>>>     out_ptr,
>>>     n_elements,
>>>     BLOCK_SIZE: "tl.constexpr",
>>> ):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> @triton_op("mylib::add", mutates_args={})
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>>
>>>     def grid(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>>     # NB: we need to wrap the triton kernel in a call to wrap_triton
>>>     wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
>>>     return output
>>>
>>> @torch.compile
>>> def f(x, y):
>>>     return add(x, y)
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>>
>>> z = f(x, y)
>>> assert torch.allclose(z, x + y)
torch.library.wrap_triton(triton_kernel, /)[source]#

允许通过 make_fx 或非严格 torch.export 将 triton 内核捕获到图中。

这些技术执行基于 Dispatcher 的跟踪(通过 __torch_dispatch__),无法看到对原始 triton 内核的调用。 wrap_triton API 将 triton 内核包装成一个可调用的对象,该对象实际上可以被跟踪到图中。

请与 torch.library.triton_op() 一起使用此 API。

示例

>>> import torch
>>> import triton
>>> from triton import language as tl
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.library import wrap_triton
>>>
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0,
>>>     in_ptr1,
>>>     out_ptr,
>>>     n_elements,
>>>     BLOCK_SIZE: "tl.constexpr",
>>> ):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> def add(x, y):
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>>
>>>     def grid_fn(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
>>>     return output
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> gm = make_fx(add)(x, y)
>>> print(gm.code)
>>> # def forward(self, x_1, y_1):
>>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
>>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
>>> #         kernel_idx = 0, constant_args_idx = 0,
>>> #         grid = [(1, 1, 1)], kwargs = {
>>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
>>> #             'n_elements': 3, 'BLOCK_SIZE': 16
>>> #         })
>>> #     return empty_like
返回类型:

任何

扩展自定义算子(用 Python 或 C++ 创建)#

使用 register.* 方法,例如 torch.library.register_kernel()torch.library.register_fake(),为任何算子(它们可能已使用 torch.library.custom_op() 或通过 PyTorch 的 C++ 算子注册 API 创建)添加实现。

torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)[source]#

为该算子的特定设备类型注册一个实现。

一些有效的 device_types 是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。此 API 可用作装饰器。

参数:
  • op (str | OpOverload) – 要注册实现的算子。

  • device_types (None | str | Sequence[str]) – 要注册实现的设备类型。如果为 None,我们将注册到所有设备类型——请仅在您的实现确实与设备类型无关时使用此选项。

  • func (Callable) – 注册为给定设备类型实现的函数。

  • lib (Optional[Library]) – 如果提供,则此注册的生命周期

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> # Add implementations for the cuda device
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
>>> def _(x):
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
torch.library.register_autocast(op, device_type, cast_inputs, /, *, lib=None)[source]#

为此自定义算子注册一个 autocast 分派规则。

有效的 device_type 包括:“cpu”和“cuda”。

参数:
  • op (str | OpOverload) – 要注册 autocast 分派规则的算子。

  • device_type (str) – 要使用的设备类型。‘cuda’或‘cpu’。该类型与 torch.devicetype 属性相同。因此,您可以使用 Tensor.device.type 获取张量的设备类型。

  • cast_inputs (torch.dtype) – 当自定义算子在启用了 autocast 的区域中运行时,会将传入的浮点 Tensor 转换为目标 dtype(非浮点 Tensor 不受影响),然后执行自定义算子并禁用 autocast。

  • lib (Optional[Library]) – 如果提供,则此注册的生命周期

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> # Create a custom op that works on cuda
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
>>> def my_sin(x: Tensor) -> Tensor:
>>>     return torch.sin(x)
>>>
>>> # Register autocast dispatch rule for the cuda device
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
>>>
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
>>> with torch.autocast("cuda", dtype=torch.float16):
>>>     y = torch.ops.mylib.my_sin(x)
>>> assert y.dtype == torch.float16
torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)[source]#

为此自定义算子注册一个后向公式。

为了让算子能够与 autograd 一起使用,您需要注册一个后向公式:1. 您必须通过提供一个“backward”函数来告诉 PyTorch 如何在后向传播中计算梯度。2. 如果您需要前向传播的任何值来计算梯度,您可以使用 setup_context 来保存值以供后向传播使用。

backward 在后向传播期间运行。它接受 (ctx, *grads):- grads 是一个或多个梯度。梯度的数量与算子的输出数量匹配。ctx 对象是 torch.autograd.Function 使用的 ctx 对象相同的对象。 backward_fn 的语义与 torch.autograd.Function.backward() 相同。

setup_context(ctx, inputs, output) 在前向传播期间运行。请将后向传播所需的值通过 torch.autograd.function.FunctionCtx.save_for_backward() 保存到 ctx 对象,或将它们作为 ctx 的属性赋值。如果您的自定义算子有仅限关键字参数,我们期望 setup_context 的签名是 setup_context(ctx, inputs, keyword_only_inputs, output)

setup_context_fnbackward_fn 都必须是可跟踪的。也就是说,它们不能直接访问 torch.Tensor.data_ptr(),并且它们不能依赖或修改全局状态。如果您需要一个不可跟踪的后向函数,您可以将其制作成一个单独的 custom_op,并在 backward_fn 中调用它。

如果您需要在不同设备上具有不同的 autograd 行为,那么我们建议创建两个不同的自定义算子,一个用于需要不同行为的每个设备,并在运行时在它们之间进行切换。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output) -> Tensor:
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_sin", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = x_np * val
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
>>>     ctx.val = keyword_only_inputs["val"]
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.val
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_mul", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1, allow_override=False)[source]#

为此算子注册一个 FakeTensor 实现(“fake impl”)。

也称为“meta kernel”、“abstract impl”。

“FakeTensor 实现”指定了该算子在不包含数据的 Tensor(“FakeTensor”)上的行为。给定具有特定属性(大小/步幅/storage_offset/设备)的输入 Tensor,它指定输出 Tensor 的属性。

FakeTensor 实现具有与算子相同的签名。它同时用于 FakeTensor 和 meta Tensor。要编写 FakeTensor 实现,请假定算子的所有 Tensor 输入都是常规的 CPU/CUDA/Meta Tensor,但它们没有存储,并且您正试图将常规 CPU/CUDA/Meta Tensor 作为输出返回。FakeTensor 实现只能由 PyTorch 操作组成(并且不能直接访问任何输入或中间 Tensor 的存储或数据)。

此 API 可用作装饰器(请参阅示例)。

有关自定义算子的详细指南,请参阅 https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html

参数:
  • op_name – 算子名称(包括重载)或 OpOverload 对象。

  • func (Callable | None) – Fake Tensor 实现。

  • lib (Optional[Library]) – 要注册 fake Tensor 的库。

  • allow_override (bool) – 控制是否覆盖现有已注册 fake impl 的标志。此选项默认关闭,如果您尝试向已具有 fake impl 的算子注册 fake impl,则会报错。这仅适用于未通过 torch.library.custom_op 创建的自定义算子,因为覆盖现有 fake impl 已经允许。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>>
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>>
>>>     return (x @ weight.t()) + bias
>>>
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>>
>>> assert y.shape == (2, 3)
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>>
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an fake impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>>
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>>
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>>
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
torch.library.register_vmap(op, func=None, /, *, lib=None)[source]#

注册一个 vmap 实现以支持此自定义算子的 torch.vmap()

此 API 可用作装饰器(请参阅示例)。

为了让算子能够与 torch.vmap() 一起使用,您可能需要注册一个 vmap 实现,其签名如下:

vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs),

其中 *args**kwargsop 的参数和关键字参数。我们不支持仅限关键字的 Tensor 参数。

它指定了在输入具有附加维度(由 in_dims 指定)时,如何计算 op 的批处理版本。

对于 args 中的每个参数,in_dims 都有一个对应的 Optional[int]。如果该参数不是 Tensor,或者该参数未被 vmap 处理,则为 None;否则,它是一个整数,指定了 Tensor 的哪个维度正在被 vmap 处理。

info 是一组附加元数据,可能很有用:info.batch_size 指定了正在 vmap 处理的维度的大小,而 info.randomness 是传递给 torch.vmap()randomness 选项。

函数 func 的返回值为一个元组 (output, out_dims)。与 in_dims 类似,out_dims 应与 output 具有相同的结构,并为每个输出包含一个 out_dim,该 out_dim 指定输出是否具有 vmapped 维度以及它在其中的索引。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>>     return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>>     x_np = to_numpy(x)
>>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>>     return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>>     result = numpy_cube(x)
>>>     return result, (in_dims[0], in_dims[0])
>>>
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @torch.library.register_vmap("mylib::numpy_mul")
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>>     x_bdim, y_bdim = in_dims
>>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>>     result = x * y
>>>     result = result.movedim(-1, 0)
>>>     return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)

注意

vmap 函数应旨在保留整个自定义运算符的语义。也就是说,grad(vmap(op)) 应该可以被 grad(map(op)) 替换。

如果您的自定义运算符在反向传播中具有任何自定义行为,请牢记这一点。

torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[source]#

此 API 在 PyTorch 2.4 中已重命名为 torch.library.register_fake()。请改用该 API。

torch.library.get_ctx()[source]#

get_ctx() 返回当前的 AbstractImplCtx 对象。

调用 get_ctx() 仅在 fake impl 内部有效(有关更多用法详细信息,请参阅 torch.library.register_fake())。

返回类型:

FakeImplCtx

torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)[source]#

为给定的运算符和 torch_dispatch_class 注册一个 torch_dispatch 规则。

这允许开放式注册来指定运算符与 torch_dispatch_class 之间的行为,而无需直接修改 torch_dispatch_class 或运算符。

torch_dispatch_class 是一个具有 __torch_dispatch__ 的 Tensor 子类,或者是一个 TorchDispatchMode。

如果是 Tensor 子类,我们期望 func 具有以下签名: (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

如果是 TorchDispatchMode,我们期望 func 具有以下签名: (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

argskwargs 将会像在 __torch_dispatch__ 中一样被标准化(请参阅 __torch_dispatch__ 调用约定)。

示例

>>> import torch
>>>
>>> @torch.library.custom_op("mylib::foo", mutates_args={})
>>> def foo(x: torch.Tensor) -> torch.Tensor:
>>>     return x.clone()
>>>
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
>>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
>>>         return func(*args, **kwargs)
>>>
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
>>> def _(mode, func, types, args, kwargs):
>>>     x, = args
>>>     return x + 1
>>>
>>> x = torch.randn(3)
>>> y = foo(x)
>>> assert torch.allclose(y, x)
>>>
>>> with MyMode():
>>>     y = foo(x)
>>> assert torch.allclose(y, x + 1)
torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)[source]#

解析给定函数的模式(schema),该函数具有类型提示。模式是从函数的类型提示推断出来的,可用于定义新运算符。

我们做出以下假设:

  • 没有输出会别名任何输入或彼此。

  • 字符串类型注解“device, dtype, Tensor, types”,如果没有指定库,
    则假定为 torch.*。类似地,字符串类型注解“Optional, List, Sequence, Union”,
    如果没有指定库,则假定为 typing.*。
  • 只有 mutates_args 中列出的参数会被修改。如果 mutates_args 是“unknown”,
    则假定运算符的所有输入都会被修改。

调用者(例如,custom ops API)负责检查这些假设。

参数:
  • prototype_function (Callable) – 要从中推断模式的函数(通过其类型注解)。

  • op_name (Optional[str]) – 模式中的运算符名称。如果 name 为 None,则模式中不包含该名称。请注意,torch.library.Library.define 的输入模式需要运算符名称。

  • mutates_args ("unknown" | Iterable[str]) – 函数中被修改的参数。

返回:

推断出的模式。

返回类型:

str

示例

>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
>>>     return x.sin()
>>>
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -> Tensor
>>>
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -> Tensor
class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn, tags=None)[source]#

CustomOpDef 是一个函数包装器,它将函数转换为自定义运算符。

它有各种方法可以为该自定义运算符注册额外的行为。

您不应该直接实例化 CustomOpDef;相反,请使用 torch.library.custom_op() API。

set_kernel_enabled(device_type, enabled=True)[source]#

禁用或重新启用此自定义运算符的已注册内核。

如果内核已禁用/启用,则此操作无效果。

注意

如果先禁用内核然后注册,则该内核处于禁用状态,直到再次启用。

参数:
  • device_type (str) – 要禁用/启用内核的设备类型。

  • disable (bool) – 是禁用还是启用内核。

示例

>>> inp = torch.randn(1)
>>>
>>> # define custom op `f`.
>>> @custom_op("mylib::f", mutates_args=())
>>> def f(x: Tensor) -> Tensor:
>>>     return torch.zeros(1)
>>>
>>> print(f(inp))  # tensor([0.]), default kernel
>>>
>>> @f.register_kernel("cpu")
>>> def _(x):
>>>     return torch.ones(1)
>>>
>>> print(f(inp))  # tensor([1.]), CPU kernel
>>>
>>> # temporarily disable the CPU kernel
>>> with f.set_kernel_enabled("cpu", enabled = False):
>>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled
torch.library.get_kernel(op, dispatch_key)[source]#

返回给定运算符和分派键的已计算内核。

此函数检索将为给定的运算符和分派键组合执行的内核。返回的 SafeKernelFunction 可用于以装箱方式调用内核。此函数的预期用途是检索给定分派键的原始内核,然后为同一分派键注册另一个内核,该内核在某些情况下会调用原始内核。

参数:
  • op (str | OpOverload | CustomOpDef) – 运算符名称(连同重载)或 OpOverload 对象。可以是字符串(例如,“aten::add.Tensor”)、OpOverload 或 CustomOpDef。

  • dispatch_key (str | torch.DispatchKey) – 用于获取内核的分派键。可以是字符串(例如,“CPU”、“CUDA”)或 DispatchKey 枚举值。

返回:

一个安全的内核函数,可用于

调用内核。

返回类型:

torch._C._SafeKernelFunction

抛出:

RuntimeError – 如果运算符不存在。

示例

>>> # Get the CPU kernel for torch.add
>>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
>>>
>>> # You can also use DispatchKey enum
>>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU)
>>>
>>> # Or use an OpOverload directly
>>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
>>>
>>> # Example: Using get_kernel in a custom op with conditional dispatch
>>> # Get the original kernel for torch.sin
>>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU")
>>>
>>> # If input has negative values, use original sin, otherwise return zeros
>>> def conditional_sin_impl(dispatch_keys, x):
>>>     if (x < 0).any():
>>>         return original_sin_kernel.call_boxed(dispatch_keys, x)
>>>     else:
>>>         return torch.zeros_like(x)
>>>
>>> lib = torch.library.Library("aten", "IMPL")
>>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet
>>> which needs to be the first argument to ``kernel.call_boxed``
>>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True)
>>>
>>> # Test the conditional behavior
>>> x_positive = torch.tensor([1.0, 2.0])
>>> x_mixed = torch.tensor([-1.0, 2.0])
>>> torch.sin(x_positive)
tensor([0., 0.])
>>> torch.sin(x_mixed)
tensor([-0.8415, 0.9093])

底层 API#

以下 API 是 PyTorch C++ 底层运算符注册 API 的直接绑定。

警告

底层运算符注册 API 和 PyTorch Dispatcher 是一个复杂的 PyTorch 概念。我们建议您在可能的情况下使用上述更高级别的 API(这些 API 不需要 torch.library.Library 对象)。这篇博文是了解 PyTorch Dispatcher 的一个好起点。

有关如何使用此 API 的一些示例教程可在 Google Colab 上找到。

class torch.library.Library(ns, kind, dispatch_key='')[source]#

一个用于创建库的类,这些库可用于从 Python 注册新运算符或覆盖现有库中的运算符。用户可以选择传入一个分派键名,如果他们只想注册对应于单个特定分派键的内核。

要创建一个用于覆盖现有库(名称为 ns)中运算符的库,请将 kind 设置为“IMPL”。要创建一个新库(名称为 ns)来注册新运算符,请将 kind 设置为“DEF”。要创建一个可能存在的库的片段,用于注册运算符(并绕过给定命名空间只有一个库的限制),请将 kind 设置为“FRAGMENT”。

参数:
  • ns – 库名

  • kind – “DEF”、“IMPL”、“FRAGMENT”

  • dispatch_key – PyTorch 分派键(默认:“”)

define(schema, alias_analysis='', *, tags=())[source]#

在 ns 命名空间中定义新运算符及其语义。

参数:
  • schema – 用于定义新运算符的函数模式。

  • alias_analysis (optional) – 指示是否可以从模式(默认行为)推断运算符参数的别名属性,或者是否不能(“CONSERVATIVE”)。

  • tags (Tag | Sequence[Tag]) – 要应用于此运算符的一个或多个 torch.Tag。标记运算符会更改运算符在各种 PyTorch 子系统下的行为;在应用之前,请仔细阅读 torch.Tag 的文档。

返回:

从模式推断出的运算符名称。

示例

>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
fallback(fn, dispatch_key='', *, with_keyset=False)[source]#

将函数实现注册为给定键的回退。

此函数仅适用于具有全局命名空间的库(“_”)。

参数:
  • fn – 用作给定分派键的回退的函数,或者 fallthrough_kernel() 来注册一个回退。

  • dispatch_key – 输入函数应注册的分派键。默认情况下,它使用创建库时使用的分派键。

  • with_keyset – 控制标志,指示在调用 fn 时是否将当前分派器调用键集作为第一个参数传递。这应该用于创建用于重新分派调用的适当键集。

示例

>>> my_lib = Library("_", "IMPL")
>>> def fallback_kernel(op, *args, **kwargs):
>>>     # Handle all autocast ops generically
>>>     # ...
>>> my_lib.fallback(fallback_kernel, "Autocast")
impl(op_name, fn, dispatch_key='', *, with_keyset=False, allow_override=False)[source]#

为库中定义的运算符注册函数实现。

参数:
  • op_name – 运算符名称(连同重载)或 OpOverload 对象。

  • fn – 作为输入分派键的运算符函数,或者 fallthrough_kernel() 来注册回退。

  • dispatch_key – 输入函数应注册的分派键。默认情况下,它使用创建库时使用的分派键。

  • with_keyset – 控制标志,指示在调用 fn 时是否将当前分派器调用键集作为第一个参数传递。这应该用于创建用于重新分派调用的适当键集。

  • allow_override – 控制是否覆盖现有已注册内核实现的标志。此选项默认为关闭,如果您尝试向已注册内核的分派键注册内核,将会报错。

示例

>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
torch.library.fallthrough_kernel()[source]#

一个虚拟函数,用于传递给 Library.impl 以注册回退。

torch.library.define(qualname, schema, *, lib=None, tags=())[source]#
torch.library.define(lib, schema, alias_analysis='')

定义一个新运算符。

在 PyTorch 中,定义一个 op(“operator”的缩写)是一个两步过程:- 我们需要定义 op(提供运算符名称和模式)- 我们需要实现运算符如何与各种 PyTorch 子系统(如 CPU/CUDA Tensor、Autograd 等)交互的行为。

此入口点定义自定义运算符(第一步),您必须通过调用各种 impl_* API(如 torch.library.impl()torch.library.register_fake())来执行第二步。

参数:
  • qualname (str) – 运算符的限定名称。应为一个看起来像“namespace::name”的字符串,例如“aten::sin”。PyTorch 中的运算符需要一个命名空间来避免名称冲突;给定的运算符只能创建一次。如果您正在编写 Python 库,我们建议命名空间是您的顶层模块的名称。

  • schema (str) – 运算符的模式。例如,对于接受一个 Tensor 并返回一个 Tensor 的 op,为“(Tensor x) -> Tensor”。它不包含运算符名称(该名称在 qualname 中传递)。

  • lib (Optional[Library]) – 如果提供,此运算符的生命周期将与 Library 对象绑定。

  • tags (Tag | Sequence[Tag]) – 要应用于此运算符的一个或多个 torch.Tag。标记运算符会更改运算符在各种 PyTorch 子系统下的行为;在应用之前,请仔细阅读 torch.Tag 的文档。

示例:
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
torch.library.impl(lib, name, dispatch_key='')[source]#
torch.library.impl(qualname: str, types: str | Sequence[str], func: None = None, *, lib: Library | None = None) Callable[[Callable[..., object]], None]
torch.library.impl(qualname: str, types: str | Sequence[str], func: Callable[..., object], *, lib: Library | None = None) None
torch.library.impl(lib: Library, name: str, dispatch_key: str = '') Callable[[Callable[_P, _T]], Callable[_P, _T]]

为该算子的特定设备类型注册一个实现。

您可以将“default”作为 types 的值来将此实现注册为所有设备类型的默认实现。请仅在实现确实支持所有设备类型时才使用此选项;例如,如果它是内置 PyTorch 运算符的组合,则情况就是如此。

此 API 可用作装饰器。您可以使用此 API 的嵌套装饰器,前提是它们返回一个函数并放置在该 API 内部(请参阅示例 2)。

一些有效的类型是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。

参数:
  • qualname (str) – 应为看起来像“namespace::operator_name”的字符串。

  • types (str | Sequence[str]) – 要注册 impl 的设备类型。

  • lib (可选[Library]) – 如果提供,此注册的生命周期将与 Library 对象的生命周期绑定。

示例

>>> import torch
>>> import numpy as np
>>> # Example 1: Register function.
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example 2: Register function with decorator.
>>> def custom_decorator(func):
>>>     def wrapper(*args, **kwargs):
>>>         return func(*args, **kwargs) + 1
>>>     return wrapper
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin_plus_one", "cpu")
>>> @custom_decorator
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>>
>>> y1 = torch.ops.mylib.sin_plus_one(x)
>>> y2 = torch.sin(x) + 1
>>> assert torch.allclose(y1, y2)