动态形状#
创建于:2023年5月19日 | 最后更新于:2025年6月10日
另请参阅:动态形状手册
动机#
深度学习编译器通常只支持静态形状,也就是说,它们生成的编译程序只能用于一个特定的输入形状配置,并且在任何输入形状发生变化时都需要重新编译。这个假设对于当今绝大多数常用的深度学习模型来说效果很好,但在某些情况下是不够的。
某些维度,例如批次大小或序列长度,可能会发生变化。例如,一个执行自适应批处理的推理服务将根据其批处理窗口接收的请求数量,使用不同的批次大小执行推理请求。我们也可能希望仅填充批次内的可变大小序列到最大序列长度,这可能因批次而异。
一些模型表现出数据依赖的输出形状,也就是说,它们的输出和中间结果的大小可能取决于实际输入数据,而这在运行时可能会有所不同。例如,检测模型可能首先生成可变数量的潜在边界框,然后再运行一个更耗时的图像识别模型来识别边界框中的主题。边界框的数量是数据依赖的。
当处理稀疏表示时,例如稀疏张量、交错张量和图神经网络,数据依赖形状的一个特别重要的案例就会出现。在所有这些情况下,要处理的数据量取决于问题的稀疏结构,而这通常会以数据依赖的方式发生变化。
在支持动态形状时,我们选择不支持动态秩程序,例如输入张量的维度会发生变化的程序,因为这种模式在现实世界的深度学习程序中很少出现,而且它避免了对形状符号列表进行归纳推理的需要。
精简版公共API#
PyTorch 2.1 中默认的动态行为是
PT2 默认假定所有内容都是静态的
如果我们因为大小改变而重新编译,我们将尝试将该大小重新编译为动态(已更改的大小很可能在将来再次更改)。这种泛化可能会失败(例如,因为用户代码在该大小上进行了条件分支,或者 PT2 中缺少动态形状支持)。如果您想了解 PT2 何时对某些代码进行了过度特化,请使用
TORCH_LOGS=dynamic
运行,并查找“eval”条目,了解何时添加了 guard 以及原因。如果您提前知道某个尺寸是动态的,可以使用
torch._dynamo.mark_dynamic(tensor, dim)
来跳过第一次重新编译。如果您提前知道该维度可以采取的min
和max
值,则可以指定torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)
。如果您指定
torch.compile(dynamic=False)
,我们将关闭重新编译时的自动动态形状,并始终为每个不同的尺寸进行重新编译。反之,如果您指定torch.compile(dynamic=True)
,我们将尽可能使所有内容都动态化。这主要对小型操作符有用;如果您尝试将其应用于大型模型,它将(1)可能导致 PT2 崩溃,并且(2)无缘无故地运行缓慢。您可以使用
TORCH_COMPILE_DYNAMIC_SOURCES
环境变量或通过设置torch.compiler.config.dynamic_sources
来白名单特定源以标记为动态。这对于具有图断点的模型特别有用,因为您可以跨图断点保持动态性,因为源名称保持一致。您也可以使用此来标记整数为动态。格式是逗号分隔的源名称列表,例如"L['x'], L['y']"
。您还可以使用正则表达式,例如"L\['x.*'\], L\['y.*'\]")
。此白名单优先于其他标志,例如dynamic=False
、force_nn_module_property_static_shapes
和force_parameter_static_shapes
。有时找到要标记为动态的正确输入可能很麻烦。如果您愿意在第一个批次上承受性能损失,我们提供的另一个可行的选择是 eager_then_compile 模式,该模式会为您推导动态性。有关更多详细信息,请参阅 torch.compiler.set_stance。
Guard 模型#
在考虑如何为 TorchDynamo 和 TorchInductor 添加动态形状支持时,我们做了一个重要的设计决策:为了重用针对 PyTorch API 编写的分解和其他现有代码,我们必须能够跟踪动态形状。与可能捕获条件分支的完全符号化系统不同,我们总是选择一个分支,并在假设我们将来会为该分支做出相同选择的情况下对我们的跟踪进行特化。为此,我们为每个符号化大小维护一个“提示”,说明其在编译时具体的取值(由于 TorchDynamo 是一个即时编译器,它总是知道实际的输入大小)。当我们对张量进行条件判断时,我们只需查阅提示即可确定采用哪个分支。
这大大简化了我们生成的符号化形状公式,但也意味着我们有一个更复杂的系统来管理 guards。例如,考虑以下程序
def f(x, y):
z = torch.cat([x, y])
if z.size(0) > 2:
return z.mul(2)
else:
return z.add(2)
我们将使用 TorchInductor 编译的最终 IR 将是 torch.cat([x, y]).add(2)
或 torch.cat([x, y]).mul(2)
(条件已展平),但要确定我们处于哪个分支,我们需要知道 z
的大小,这是一个中间结果。由于 TorchDynamo 必须预先知道编译的跟踪是否有效(我们不支持像某些 JIT 编译器那样的逃逸机制),我们必须能够将 z.size(0)
简化为输入表达式,即 x.size(0) + y.size(0)
。这是通过为 PyTorch 中的所有操作符编写元函数来实现的,这些函数可以在不实际执行节点计算的情况下将大小信息传播到张量的输出。
整体架构#
符号形状工作流
当我们在 Dynamo 中开始编译一个帧时,我们会分配一个 ShapeEnv(附加到 FakeTensorMode),它负责跟踪符号形状状态。
我们在张量进入时为其分配符号大小(静态还是动态是一个策略决策,有一些可调的旋钮)。
我们通过操作符传播符号大小,同时维护(1)FX IR,以便我们能够忠实地导出符号计算,以及(2)代表大小变量的 Sympy 表达式,以便我们可以对其进行推理。
当我们根据符号大小进行条件判断时,无论是在 Dynamo 跟踪中还是在 Inductor 优化中,我们都会根据条件添加 guards。这些 guards 可以从 Python 和 C++ 中诱导。
这些 guards 可以诱导对符号变量进行进一步的简化。例如,如果您断言
s0 == 4
,我们现在可以将s0
的所有出现都替换为4
。当我们完成跟踪和优化时,我们将所有这些 guards 安装到编译后的代码中;只有当所有 guards 的评估结果都为 true 时,编译后的代码才是可重用的。
重要文件
C++ SymInt API:
c10/core/SymInt.h
、SymFloat.h
、SymBool.h
Python SymInt API:
torch/__init__.py
(查找SymInt/SymFloat/SymBool
)C++ 管道:
c10/core/SymNodeImpl.h
、torch/csrc/utils/python_symnode.h
、torch/csrc/jit/python/init.cpp
Python 基础设施:
torch/fx/experimental/symbolic_shapes.py
其他重要文件:
torch/_subclasses/fake_tensor.py
、torch/_meta_registrations.py
、decomps、PrimTorch refs
精简版内部 API#
理解 Python 类层次结构
SymInt/SymFloat/SymBool:这些是用户可见的类,模拟了它们的 int/float/bool 对等项。如果您将两个 SymInt 相加,我们将为您提供一个新的 SymInt,它在符号上跟踪整数加法已发生。
SymNode:这是内部结构(可通过例如
symint.node
访问),它保存实际的符号跟踪信息。SymNode 是经过类型擦除的;这使得表示混合类型操作更加方便。请注意,技术上您不必从 SymInt 调用 Python SymNode;例如,XLA 的 C++SymNodeImpl
将取代 SymNode。ShapeEnv:每个编译的上下文状态,它跟踪我们到目前为止累积的所有自由符号和保护。每个 SymNode 都记录其 ShapeEnv(但反之则不然;仅当 SymNodes 参与保护时才会使用它们)。
C++ 相当类似
c10::SymInt/SymFloat/SymBool:用户可见的类,模拟 int/float/bool。
c10::SymNode/SymNodeImpl:类似于 SymNode
C++ 中没有 ShapeEnv;为了便于调试,整个符号推理机制都在 Python 中。
当您编写使用 make_fx
可跟踪的代码时,它必须能够处理其中流动的 SymInt/SymFloat/SymBool。 动态形状手册 提供了一些如何执行此操作的指导。
未备份的 SymInts#
为了解决控制流,我们检查符号整数的提示,即实际值,以确定分支。但是,在某些情况下,我们可能没有提示:所谓的未备份符号整数是在数据依赖操作(如 .nonzero()
或 .item()
)中出现大小变量时产生的。对这些符号整数执行控制流是非法的,因此我们必须在这些操作上进行图中断。
天真地实现,这过于严格:如果您尝试对未备份的符号整数执行任何操作,大多数 PyTorch 程序将立即失败。以下是使此功能真正起作用的最重要的增强功能。
在张量创建时,PyTorch 会预先计算有关张量的许多数据;例如,如果您使用
empty_strided
创建张量,我们将对其步幅进行排序,并确定张量是否不重叠且密集。排序会产生大量保护。然而,更常见的是使用像empty
这样的更高级 API 直接生成张量,该 API 保证生成不重叠且密集的张量。我们修改了 PyTorch 以避免不必要地重新计算这些属性。即使需要进行复杂的计算,有时某个属性根本不会被查询。使这些预计算的属性惰性化,可以使我们避免对未备份的符号整数进行保护,除非确实需要。
整数张量中的数据通常未知是否为非负数。但是,我们提供了一个 API
constrain_range
,用户可以通过它指定大小被已知限制的上界和下界。
与动态 API 类似,存在相应的未备份 API:即您可以使用 mark_unbacked 代替 mark_dynamic
,并使用 TORCH_COMPILE_UNBACKED_SOURCES
代替 TORCH_COMPILE_DYNAMIC_SOURCES
来告诉编译器将输入标记为未备份。
在 PT2 的未来版本(PT2.1 之后)中,我们将扩展我们的推理系统,以根据用法推断未备份的符号整数是大小类型的。例如,如果您将 .item()
调用的结果传递给像 torch.empty
这样的工厂函数,我们将自动推断结果是大小(因为如果不是,它会失败)。此假设将在运行时进行验证,如果未满足,将引发错误。