torch.overrides#
创建于: Nov 30, 2020 | 最后更新于: Jun 06, 2025
此模块公开了用于 __torch_function__ 协议的各种辅助函数。有关 __torch_function__ 协议的更多详细信息,请参阅 扩展 torch Python API。
函数#
- torch.overrides.get_ignored_functions()[source]#
返回不能被
__torch_function__覆盖的公共函数。- 返回:
一个公共函数元组,这些函数在 torch API 中是公开可用的,但不能通过
__torch_function__覆盖。主要是因为这些函数没有一个参数是张量或类似张量的值。- 返回类型:
set[Callable]
示例
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() True >>> torch.add in torch.overrides.get_ignored_functions() False
- torch.overrides.get_overridable_functions()[source]#
列出可通过 __torch_function__ 覆盖的函数
- 返回:
一个字典,它将包含可覆盖函数的命名空间映射到该命名空间中可以被覆盖的函数。
- 返回类型:
Dict[Any, List[Callable]]
- torch.overrides.resolve_name(f)[source]#
获取传递给 __torch_function__ 的函数的易读字符串名称。
- 参数:
f (Callable) – 要解析名称的函数。
- 返回:
函数的名称;如果求值,它应该会返回输入函数。
- 返回类型:
- torch.overrides.get_testing_overrides()[source]#
返回一个字典,其中包含所有可覆盖函数的虚拟覆盖。
- 返回:
一个字典,它将 PyTorch API 中的可覆盖函数映射到具有与实际函数相同签名的 lambda 函数,并无条件地返回 -1。这些 lambda 函数对于测试支持
__torch_function__的类型的 API 覆盖率很有用。- 返回类型:
Dict[Callable, Callable]
示例
>>> import inspect >>> my_add = torch.overrides.get_testing_overrides()[torch.add] >>> inspect.signature(my_add) <Signature (input, other, out=None)>
- torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[source]#
实现一个具有
__torch_function__覆盖检查的函数。请参阅 torch::autograd::handle_torch_function 以了解此函数在 C++ 实现中的等效功能。
- 参数:
- 返回:
调用
implementation或__torch_function__方法的结果,视情况而定。- 返回类型:
:raises TypeError : 如果找不到实现。
示例
>>> def func(a): ... if has_torch_function_unary(a): ... return handle_torch_function(func, (a,), a) ... return a + 0
- torch.overrides.has_torch_function()#
检查可迭代元素中的 __torch_function__ 实现,或检查是否启用了 __torch_function__ 模式。精确的
Tensor和Parameter被认为是不可分派的。使用此函数来保护对handle_torch_function()的调用;不要使用它来测试某个对象是否是类似张量的值,而是使用is_tensor_like()。:param relevant_args: 用于检查 __torch_function__ 方法的参数的可迭代对象。:type relevant_args: iterable- 返回:
如果 relevant_args 的任何元素具有 __torch_function__ 实现,则为 True,否则为 False。
- 返回类型:
另请参阅
torch.is_tensor_like检查某个对象是否是类似张量的值,包括精确的
Tensor。
- torch.overrides.is_tensor_like(inp)[source]#
如果传入的输入是类似张量的值,则返回
True。目前,当输入类型的
__torch_function__属性存在时,就会发生这种情况。示例
张量的子类通常是类似张量的值。
>>> class SubTensor(torch.Tensor): ... >>> is_tensor_like(SubTensor([0])) True
内置类型或用户自定义类型通常不是类似张量的值。
>>> is_tensor_like(6) False >>> is_tensor_like(None) False >>> class NotATensor: ... >>> is_tensor_like(NotATensor()) False
但是,可以通过实现 __torch_function__ 来使其成为类似张量的值。
>>> class TensorLike: ... @classmethod ... def __torch_function__(cls, func, types, args, kwargs): ... return -1 >>> is_tensor_like(TensorLike()) True
- torch.overrides.is_tensor_method_or_property(func)[source]#
如果传入的函数是属于
torch.Tensor的方法或属性的处理器(在传递给__torch_function__时),则返回 True。注意
对于属性,必须传入它们的
__get__方法。这可能尤其需要,原因如下:
方法/属性有时不包含 __module__ 插槽。
它们要求第一个传入的参数是
torch.Tensor的实例。
示例
>>> is_tensor_method_or_property(torch.Tensor.add) True >>> is_tensor_method_or_property(torch.add) False
- 返回类型:
- torch.overrides.wrap_torch_function(dispatcher)[source]#
使用与
__torch_function__相关的函数来包装给定的函数。- 参数:
dispatcher (Callable) – 一个可调用对象,它返回传递到函数的类似张量的值的可迭代对象。
注意
此装饰器可能会降低代码的性能。通常,将代码表示为一系列支持 __torch_function__ 的函数就足够了。如果您发现自己在不常见的情况下(例如,您正在包装一个低级库,并且您还需要它对类似张量的值起作用),则可以使用此函数。
示例
>>> def dispatcher(a): # Must have the same signature as func ... return (a,) >>> @torch.overrides.wrap_torch_function(dispatcher) >>> def func(a): # This will make func dispatchable by __torch_function__ ... return a + 0