Fake tensor#
创建于:2023年5月19日 | 最后更新于:2025年6月13日
动机#
在进行 Dynamo 符号化评估和编译器传递时,我们经常希望能够运行张量操作来了解输出尺寸/数据类型/设备,而无需实际执行这些操作(或干扰已有的张量),因为这会更慢(如果您进行大量计算)并且占用大量内存(如果您的编译器需要在编译程序时使用 GPU 内存,那将是糟糕的)。Fake tensor 在各个方面都像真实的张量,只是它实际上不包含任何数据。例如,当我们进行 Dynamo 跟踪时,我们需要跟踪用户的张量代码并回答关于中间结果的问题(例如,如果用户在中间张量上执行条件判断)。没有 fake tensor,我们就无法获得这些查询的准确信息。
同样,假设您想为张量存储元数据,例如在 FX IR 节点上(meta[‘val’])。您可以改为直接在节点上存储 fake tensor,它将为您提供张量所需的所有元数据,包括您可能没有处理到的细微之处(例如,别名关系)。
总体架构#
所有Fake tensor都与FakeTensorMode关联。由于Fake tensor的主要用例是对真实tensor进行分析,因此一般的工作流程是:拥有大量真实tensor,分配一个FakeTensorMode,然后使用from_real_tensor将所有这些真实tensor转换为Fake tensor,之后对Fake tensor进行操作。特别地,FakeTensorMode会持久地维护一个memo表,将tensor(及存储)映射到相同的存储。如果你多次fakeify同一个tensor,你会得到相同的fake tensor;如果你fakeify两个相互别名的tensor,你会得到两个fake tensor,它们别名同一个fake存储。FakeTensors是tensor的子类,所以如果你对它们进行操作,你会自动得到一个fake tensor,但通常你希望在FakeTensorMode激活的情况下对fake tensors进行操作(例如,如果你正在运行FX pass);tensor操作会自动开启fake tensor模式并重试。
Fake tensor表示为meta tensor的__torch_dispatch__ tensor子类。这意味着在底层,fake tensors是meta device tensors;然后它们使用额外的可扩展性钩子,特别是dispatch_device,来欺骗实际tensor的设备。这是早期fake tensors中最容易出错的部分之一:有时,fake tensors在欺骗自己是CPU/CUDA等设备方面过于“优秀”,导致CPU内核被调用,而fake tensor试图解引用数据指针,这显然行不通。如果你在fake tensor代码中遇到段错误,这是你应该首先检查的地方:C++回溯是否在CPU内核(非预期!)或meta内核(预期!)中。meta内核就像一个真正的内核,但它所做的只是分配输出,它不执行任何数据计算。
tensor子类必须定义如何实现各种操作。以下是通用的fake tensor方法:
通过in_kernel_invocation_manager这个特殊的上下文管理器,在输入fake tensors上运行meta内核,并将它们重新解释为meta tensors。这个管理器会指示PyTorch将fake tensors视为其底层的meta tensors,而不是“解包”fake tensors为meta tensors(fake tensor本身就是一个meta tensor)。Fake tensors之所以这样表示,是为了避免同步两套元数据(meta tensor的元数据和fake tensor的元数据);“is a”关系确保只有一份规范的元数据副本。
如果你是一个工厂函数,你会改为调用设备设置为'meta'的底层工厂函数。
将生成的meta tensor转换为fake tensor,计算tensor的输出设备应该是什么(这通常很简单,但有时并非如此,例如CPU标量提升,或设备转换操作)。
API:重要部分#
非PT2用法(查看test/test_fake_tensor.py以获取更多示例)
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
# Do some operations on the fake tensors
fake_y = fake_x * 2
# Factory operations automatically get fakeified in the context manager
fake_z = torch.empty(20)
问:为什么输入是真实tensor?
答:在PT2上下文中,这是因为你通常是即时编译,所以对于你要编译的图的所有输入,你已经有了“真实”的输入,因为你在程序执行时进行编译。
PT2 AOTAutograd之前的用法(这不常见,你可能不希望这样做)
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
... # do stuff with the fake args, if needed ...
detect_fake_mode会搜索多个位置来尝试找到与生命周期关联的“那个”fake tensor模式。通常,它会从跟踪上下文中获取。
PT2 AOTAutograd之后的用法
# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on
其他有用内容
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
... # fake mode is disabled here, you can do real tensor compute
何时可能需要禁用fake tensor模式?通常你不需要这样做。我们发现一个特殊的用例是实现fake tensors上的常量传播:在这种情况下,即使在fake tensor模式下,我们也需要进行一些实际的tensor计算。
import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
细节#
自动转换还是不自动转换?最初,FakeTensorMode不会在你尝试在FakeTensorMode区域内对真实tensor进行计算时自动fakeify它们。这样做的目的是防止以下“陷阱”:
with FakeTensorMode():
real_tensor.t_()
这段代码应该做什么?如果我们实际修改了真实tensor的元数据,那将是令人惊讶的。但同时,也没有明显的创建FakeTensor的机会。因此,我们保守地决定抛出一个错误:“在FakeTensorMode中使用非Fake Tensor输入调用运算符尚不支持。请先将所有Tensors转换为FakeTensors。”
这个错误在实践中非常烦人。例如,假设你有一个真实的nn.Module,你想让fake tensors通过它。你需要以某种方式fakeify nn.Module。这促使了FakeCopyMode的出现。
最终,我们放弃了,并添加了自动fakeification。然而,在许多FakeTensorMode的用法中,这仍然默认未启用。
Fake tensor上的元数据变异:如果你有一个fake tensor,然后对其进行t_()操作,fake tensor上的元数据会发生变化。表面上看这是合理的,但有时你还想将fake tensors作为元数据存储在FX节点上;变异fake tensor是糟糕的,因为这会使旧的元数据失效!
事实上,这里存在一个根本性的矛盾,即fake tensors维护着张量极其准确的元数据,包括对象身份。如果FX图中的对象元数据随时间变化,实际上没有办法表示这种随时间的变化。大多数时候,我们严肃的FX分析是在函数化图上进行的,这些图没有这个特性,但偶尔你需要在非函数化图上进行分析。也许把fake tensor放在meta[‘val’]是一个错误。
关于tensor子类#
Fake tensor同时使用了子类和模式(mode)子类模式,其中FakeTensor.__torch_dispatch__ 启用与fake tensor关联的FakeTensorMode,然后重新分派(依赖FakeTensorMode进行繁重的工作)。如果fake tensor操作收到一个它不识别的子类参数,它将返回NotImplemented,让另一个子类有机会先运行(希望将其解糖为普通tensor操作),然后再重试。这可能导致无限循环。
每个单独的运算符是如何实现的?#
不幸的是,任何给定运算符的实现都有一个相当复杂的位置集合。一些重要的需要了解的情况:
Tensor子类对数量非常少的元素支持有限的常量传播(这有助于处理我们立即调用item()的某些情况)。
我们为某些运算符提供了一些快速路径实现,这些实现完全在fake tensor中完成,出于性能原因。
如果你使用@custom_op生成自定义tensor,它们将直接向fake tensor注册impl_abstract。
Fake tensor本身对设备转换操作有一些硬编码的特殊情况。
如果没有meta实现也没有分解,我们将生成真实的零填充tensor并尝试直接运行运算符以找出结果。如果运算符尝试使用数据进行索引,这可能会导致段错误,因此我们默认不对自定义操作启用此功能。
转换器是如何工作的?#
由于fake tensors的使用场景对tensor的精确属性非常敏感,fake tensors进行转换非常小心,会保留leaf-ness、requires_grad-ness、aliasing以及许多其他属性。大部分繁重的工作由MetaConverter完成。
性能特点#
你会认为fake tensors很快,因为它们不执行任何tensor计算。但在小张量尺寸下,我们实际上完全受限于开销,而且,fake tensor是用Python实现的,我们通常会做很多工作来完成一个张量操作(因为它们是作为分解来实现的)。因此,fake tensors实际上在实践中相当慢,尤其是在涉及符号形状时。我们目前在fake tensor中有两个重要的快速路径,它们在实践中产生了很大的不同:
Pointwise ops不经过PrimTorch分解,而是我们手动编码了它们的传播规则。
如果可能,我们应该这样做。
Fake tensor of fake tensor?#
人们有兴趣将fake tensors作为用户输入发送到PT2堆栈,这意味着我们需要能够创建fake tensor of fake tensor。这目前还没有得到很好的支持,但也许并不难实现。
与动态形状的交互#
每个FakeTensorMode都包含一个ShapeEnv,它跟踪所有符号形状信息。它们的生命周期通常是绑定的:它们一起生存,一起死亡。
由于FakeTensorMode拥有ShapeEnv(而meta实现没有),因此依赖于数据的meta函数需要分配一个未绑定的SymInt,这些函数会驻留在fake tensor中。Fake tensor还负责记忆未绑定的SymInts,例如,如果你对同一个fake tensor调用两次nonzero(),你会得到相同的符号大小。