评价此页

Fake tensor#

创建日期:2023 年 5 月 19 日 | 最后更新日期:2025 年 6 月 13 日

代码: fake_tensor.py

动机#

在进行 Dynamo 符号求值和编译器传递时,我们经常希望能够运行张量运算来了解输出的尺寸/数据类型/设备,而无需实际运行这些运算(或破坏已有的张量),因为那样会更慢(如果进行大量计算)并且占用大量内存(如果您的编译器在编译程序时需要使用 GPU 内存,那将非常糟糕)。Fake tensor 在所有方面都类似于真实张量,只是它实际上没有任何数据。例如,在进行 Dynamo 追踪时,我们需要追踪用户的 Tensor 代码并回答关于中间结果的问题(例如,如果用户对中间张量进行条件判断)。没有 fake tensor,我们将无法获得这些查询的准确信息。

类似地,假设您想存储张量的元数据,例如存储在 FX IR 节点上(meta[‘val’])。您可以直接在节点上存储一个 fake tensor,它将为您提供张量所需的所有元数据,包括您可能未考虑到的细微之处(例如,别名关系)。

整体架构#

所有 fake tensor 都与 FakeTensorMode 相关联。因为 fake tensor 的主要用例是对真实张量进行分析,所以一般的工作流程是:您有一堆真实张量,分配一个 FakeTensorMode,然后使用 from_real_tensor 将所有这些真实张量转换为 fake tensor,然后对 fake tensor 进行操作。特别是,FakeTensorMode 维护一个持久的映射表,将张量(和存储)映射到相同的存储。如果您多次 fakeify 同一个张量,您将得到相同的 fake tensor;如果您 fakeify 两个相互别名的张量,您将得到两个别名相同 fake storage 的 fake tensor。FakeTensors 是 tensor subclass,因此如果您对它们进行运算,您将自动获得一个 fake tensor,但通常您希望在 FakeTensorMode 处于活动状态时对 fake tensor 进行运算(例如,如果您正在运行 FX 传递);张量运算会自动开启 fake tensor 模式并重试。

Fake tensor 表示为 meta tensor 的 __torch_dispatch__ tensor subclass。这意味着在底层,fake tensor 是 meta device tensors;然后它们使用额外的可扩展性钩子,特别是 dispatch_device,来谎报张量的实际设备。这是早期 fake tensor 中比较容易出错的部分:有时,fake tensor 会过于擅长伪装成 CPU/CUDA 等,最终会导致 CPU 内核被调用,而 fake tensor 试图解引用数据指针,这显然行不通。如果您在 fake tensor 代码中遇到段错误,这是您应该首先检查的地方:C++ 回溯是否在 CPU 内核(意外!)还是 meta 内核(预期!)中?Meta 内核类似于真实内核,但它所做的只是分配输出,而不执行任何数据计算。

Tensor subclass 必须定义如何实现各种运算。这是通用的 fake tensor 过程:

  • 在输入 fake tensors 上运行 meta 内核,将它们重新解释为 meta tensors。这是通过一个特殊的上下文管理器 in_kernel_invocation_manager 来完成的,它指示 PyTorch 将 fake tensors 视为其底层的 meta tensors,而不是“解包” fake tensors 为 meta tensors(fake tensor 是 meta tensor)。Fake tensors 的表示方式是为了避免同步两组元数据(meta tensor 的元数据和 fake tensor 的元数据);“is a”关系确保只有一组规范的元数据副本。

  • 如果您是工厂函数,则会改用 device='meta' 调用底层工厂函数。

  • 将生成的 meta tensor 转换为 fake tensor,计算张量的输出设备应该是什么(这通常是微不足道的,但有时并非如此,例如,CPU 标量提升,或设备转换运算)。

API:重要部分#

非 PT2 用途(查看 test/test_fake_tensor.py 以获取更多示例)

# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
    # Do some operations on the fake tensors
    fake_y = fake_x * 2
    # Factory operations automatically get fakeified in the context manager
    fake_z = torch.empty(20)

问:为什么输入是真实的张量?

答:在 PT2 上下文中,这是因为您通常是即时编译,因此对于要编译的图的所有输入,您已经拥有“真实”的输入,因为您在程序执行时进行编译。

PT2 AOTAutograd 之前的使用(这很不寻常,您可能不想这样做)

# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
    ... # do stuff with the fake args, if needed ...

detect_fake_mode 将在多个位置搜索以尝试找到与生命周期关联的“那个” fake tensor 模式。通常它会从追踪上下文中提取。

PT2 AOTAutograd 之后的使用

# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on

其他有用信息

from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
    ... # fake mode is disabled here, you can do real tensor compute

何时可能希望禁用 fake tensor 模式?通常您不希望这样做。我们发现一个有用的细微之处是实现 fake tensor 上的常量传播:在这种情况下,即使在 fake tensor 模式下,我们也需要进行一些实际的张量计算。

import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)

细节#

自动转换还是不自动转换?最初,FakeTensorMode 不会在 FakeTensorMode 区域内尝试计算时自动 fakeify 真实张量。这样做的动机是为了防止以下“脚踏实地”的陷阱:

with FakeTensorMode():
    real_tensor.t_()

这段代码应该做什么?如果我们实际上修改了真实张量上的元数据,那将是令人惊讶的。但同时,也没有明显的创建 FakeTensor 的机会。因此,我们保守地决定引发错误:“在 FakeTensorMode 中调用具有非 Fake Tensor 输入的运算符尚不支持。请先将所有 Tensor 转换为 FakeTensors。”

在实践中,这个错误非常烦人。例如,假设您有一个真实的 nn.Module,并希望通过它传递 fake tensor。您需要某种方式来 fakeify nn.Module。这促成了 FakeCopyMode 的出现。

最终,我们放弃了,并添加了自动 fakeification。然而,在许多 FakeTensorMode 的用法中,这仍然不是默认启用的。

Fake tensor 上的元数据变异 如果您有一个 fake tensor,并且对其执行 t_(),fake tensor 上的元数据会发生变化。表面上看这是合理的,但有时您也希望将 fake tensor 作为元数据存储在 FX 节点上;变异 fake tensor 是不好的,因为它会使旧的元数据失效!

事实上,这里存在一个根本性的矛盾,即 fake tensors 维护着关于张量的极其准确的元数据,直到包括对象身份。如果 FX 图中的对象元数据随时间变化,实际上没有办法表示这种随时间的变化。大多数时候,我们对 FX 的严肃分析是在函数化图上进行的,而这些图没有这个特性,但偶尔您需要在非函数化图上进行分析。也许将 fake tensor 放入 meta[‘val’] 是个错误。

关于 tensor subclass#

Fake tensor 同时使用了 subclass 和 mode tensor subclass 模式,其中 FakeTensor.__torch_dispatch__ 启用了与 fake tensor 关联的 FakeTensorMode,然后重新调度(依赖 FakeTensorMode 来完成繁重的工作)。如果 fake tensor 运算收到一个它不识别的 subclass 参数,它将返回 NotImplemented,让另一个 subclass 有机会先运行(希望将其“脱糖”为普通张量运算),然后再重试。这可能导致无限循环。

每个单独的运算符是如何实现的?#

不幸的是,任何给定的运算符可能实现的地方非常复杂。以下是一些重要的需要了解的情况:

  • Tensor subclasses 支持有限的常量传播,如果元素数量非常少(这有助于处理一些我们立即调用 item() 的情况)。

  • 出于性能原因,我们为某些运算符提供了一些快速路径实现,这些实现完全在 fake tensor 中完成。

  • 如果您使用 @custom_op 来生成自定义张量,这些张量将直接向 fake tensor 注册 impl_abstract。

  • Fake tensor 本身对设备转换运算有一些硬编码的特殊情况。

  • 如果没有 meta 实现也没有分解,我们将生成真实的零填充张量并尝试直接运行该运算符以找出结果。这可能会导致段错误,如果该运算符尝试使用数据进行索引,因此我们默认不对自定义运算符启用此功能。

转换器是如何工作的?#

因为 fake tensor 用于对张量的精确属性非常敏感的场景,所以 fake tensor 的转换非常小心,会保留 leaf-ness、requires_grad-ness、别名以及许多其他属性。大部分繁重的工作都在 MetaConverter 中。

性能特征#

您可能会认为 fake tensor 速度很快,因为它们不进行任何张量计算。但在小张量尺寸下,我们完全受限于开销,而且,fake tensor 是用 Python 编写的,我们经常为单个张量运算执行大量工作(因为它们是作为分解实现的)。所以实际上,fake tensor 速度相当慢,尤其是在涉及符号形状时。目前 fake tensor 有两个重要的快速路径,在实践中效果显著:

  • Pointwise 运算不通过 PrimTorch 分解,而是我们手工编码了它们的传播规则。

  • 如果可能,我们应该这样做。

Fake tensor 的 fake tensor?#

人们有兴趣将 fake tensor 作为用户输入发送到 PT2 堆栈,这意味着我们需要能够创建一个 fake tensor 的 fake tensor。这目前并没有真正得到支持,但也许做起来不会太难。

与动态形状的交互#

每个 FakeTensorMode 都包含一个 ShapeEnv,用于跟踪所有符号形状信息。它们的生命周期通常是绑定的:它们一起存在,一起消亡。

因为 FakeTensorMode 有一个 ShapeEnv(但 meta 实现没有),所以依赖于数据的、需要分配未备份的 SymInt 的 meta 函数存在于 fake tensor 中。Fake tensor 还负责 memoize 未备份的 SymInt,因此,例如,如果您对同一个 fake tensor 调用 nonzero() 两次,您会得到相同的符号大小。