评价此页

动态形状#

创建日期:2023年5月19日 | 最后更新日期:2025年6月10日

代码: symbolic_shapes.py

另请参阅: 动态形状手册

动机#

深度学习编译器通常只适用于静态形状,也就是说,它们生成的编译程序只适用于单一的特定输入形状配置,并且在任何输入形状发生变化时都必须重新编译。这个假设对于当今绝大多数常用的深度学习模型来说效果很好,但在少数情况下是不够的。

  • 某些维度,例如批次大小或序列长度,可能会有所不同。例如,执行自适应批次的推理服务将根据其批次窗口内收到的请求数量,以不同的批次大小执行推理请求。我们可能还希望仅将可变大小的序列填充到批次内的最大序列长度,而这个最大序列长度可能因批次而异。

  • 某些模型表现出数据依赖的输出形状,也就是说,其输出和中间张量的尺寸可能取决于实际输入数据,而这些数据在不同运行之间可能会有所不同。例如,检测模型可能首先生成可变数量的潜在边界框,然后再运行一个更昂贵的图像识别模型来确定主题是否在边界框内。边界框的数量是数据依赖的。

  • 数据依赖形状的一个特别重要的案例发生在处理稀疏表示时,例如稀疏张量、锯齿状张量和图神经网络。在所有这些情况下,要处理的数据量取决于问题的稀疏结构,而这通常会以数据依赖的方式变化。

在支持动态形状时,我们选择不支持动态秩程序,例如,输入张量的维度会发生变化的程序,因为这种模式在实际深度学习程序中很少出现,而且它避免了对形状符号列表进行归纳推理的需要。

精简公共API#

PyTorch 2.1 中的默认动态行为是

  • PT2 默认假定所有内容都是静态的

  • 如果因为某个尺寸发生变化而重新编译,我们将尝试将该尺寸作为动态尺寸重新编译(已更改的尺寸很可能在将来继续变化)。这种泛化可能会失败(例如,因为用户代码对所讨论的尺寸进行了条件分支,或者 PT2 中缺少动态形状支持)。如果您想了解 PT2 为什么对某些代码进行了过度特化,请运行 TORCH_LOGS=dynamic 并查找显示何时添加 guard 以及原因的“eval”条目。

  • 如果您事先知道某个张量将是动态的,您可以使用 torch._dynamo.mark_dynamic(tensor, dim) 跳过第一次重新编译。如果您事先知道此维度可以取 minmax 值,您可以指定 torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)

  • 如果您指定 torch.compile(dynamic=False),我们将关闭自动动态形状的重编译,并始终为每个不同的尺寸进行重编译。反之,如果您指定 torch.compile(dynamic=True),我们将尝试尽可能地使所有内容动态化。这对于小型运算符很有用;如果您在大型模型上尝试此操作,它将(1)很可能导致 PT2 崩溃,并且(2)运行缓慢而没有好理由。

  • 您可以使用 TORCH_COMPILE_DYNAMIC_SOURCES 环境变量或设置 torch.compiler.config.dynamic_sources 来白名单特定源以标记为动态。这对于具有图中断的大型模型特别有用,因为您可以跨图中断保持动态性,因为源名称保持一致。您还可以使用它来标记整数为动态。格式是逗号分隔的源名称列表,例如 "L['x'], L['y']"。您也可以使用正则表达式,例如 "L\['x.*'\], L\['y.*'\]")。此白名单优先于其他标志,如 dynamic=Falseforce_nn_module_property_static_shapesforce_parameter_static_shapes

  • 有时找出要标记为动态的正确输入可能很麻烦。如果您愿意在第一个批次上承受性能损失,我们还有另一个经济实惠的选项是 eager_then_compile 模式,它会为您推导动态性。有关更多详细信息,请参阅 torch.compiler.set_stance

Guard 模型#

在考虑如何为 TorchDynamo 和 TorchInductor 添加动态形状支持时,我们做了一个重要的设计决策:为了重用针对 PyTorch API 编写的分解和其他现有代码,我们必须能够跟踪动态形状。与可能捕获条件的两个分支的完全符号化系统不同,我们总是选择一个分支,并在假设我们将来仅在对该分支做出相同选择时使用此跟踪的假设下特化我们的跟踪。为此,我们为每个符号化大小维护一个“提示”,说明其在编译时的具体值(由于 TorchDynamo 是即时编译器,它总是知道实际的输入大小。)当我们对张量进行条件判断时,我们只需查阅提示即可确定采取哪个分支。

这极大地简化了我们生成的符号形状公式,但意味着我们有一个更复杂的 guard 管理系统。例如,考虑以下程序

def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

我们将用 TorchInductor 编译的最终 IR 将是 torch.cat([x, y]).add(2)torch.cat([x, y]).mul(2)(条件已展平),但要确定我们在哪个分支,我们需要知道 z 的大小,这是一个中间值。由于 TorchDynamo 必须预先知道编译跟踪是否有效(我们不支持像某些 JIT 编译器那样的退出),我们必须能够将 z.size(0) 表示为输入 x.size(0) + y.size(0) 的表达式。这是通过为 PyTorch 中的所有运算符编写元函数来实现的,这些元函数可以在不实际对节点执行计算的情况下将大小信息传播到张量的输出。

总体架构#

符号形状工作流程

  1. 当我们在 Dynamo 中开始编译一个帧时,我们会分配一个 ShapeEnv(附加到 FakeTensorMode),它会跟踪符号形状状态。

  2. 我们在进入时为张量分配符号大小(静态还是动态是策略决策,有一些调整)。

  3. 我们通过运算符传播符号大小,同时维护(1)FX IR 以便我们能够忠实地导出符号计算,以及(2)表示大小变量的 Sympy 表达式,以便我们能够对它们进行推理。

  4. 当我们根据符号大小进行条件判断时,无论是在 Dynamo 跟踪还是在 Inductor 优化中,我们都会根据条件添加 guard。这些 guard 可以从 Python 和 C++ 中诱导。

  5. 这些 guard 可以对符号变量诱导进一步的简化。例如,如果您断言 s0 == 4,我们现在可以将 s0 的所有出现替换为 4

  6. 当我们完成跟踪和优化时,我们将所有这些 guard 安装到编译后的代码中;只有当所有 guard 都评估为 true 时,编译后的代码才是可重用的。

重要文件

  • C++ SymInt API: c10/core/SymInt.hSymFloat.hSymBool.h

  • Python SymInt API: torch/__init__.py(查找 SymInt/SymFloat/SymBool

  • C++ 管道: c10/core/SymNodeImpl.htorch/csrc/utils/python_symnode.htorch/csrc/jit/python/init.cpp

  • Python 基础设施: torch/fx/experimental/symbolic_shapes.py

  • 其他重要文件: torch/_subclasses/fake_tensor.pytorch/_meta_registrations.py、分解、PrimTorch 引用

精简内部 API#

理解 Python 类层次结构

  • SymInt/SymFloat/SymBool:这些是用户可见的类,模拟其对应的 int/float/bool。如果您将两个 SymInt 相加,我们会给您一个新的 SymInt,它会符号化地跟踪整数加法已发生。

  • SymNode:这是内部结构(可通过例如 symint.node 访问),它保存实际的符号跟踪信息。SymNode 是类型擦除的;这使得表示混合类型操作更加方便。请注意,技术上您不必从 SymInt 调用 Python SymNode;例如,XLA 的 C++ SymNodeImpl 将取代 SymNode。

  • ShapeEnv:每次编译的上下文状态,它跟踪我们到目前为止累积的所有自由符号和 guard。每个 SymNode 都记录其 ShapeEnv(但反之则不然;只有当 SymNode 参与 guard 时才会使用它们)。

C++ 也很相似

  • c10::SymInt/SymFloat/SymBool:用户可见的类,模拟 int/float/bool。

  • c10::SymNode/SymNodeImpl:类似于 SymNode

  • C++ 中没有 ShapeEnv;为了方便调试,整个符号推理机制都在 Python 中。

当您编写可以用 make_fx 跟踪的代码时,它必须能够处理其中的 SymInt/SymFloat/SymBool 流。 动态形状手册 提供了一些关于如何执行此操作的指导。

DimDynamic 策略#

符号推理

  • 值范围

  • Sympy 用法说明

  • Constraints

  • DimDynamic/Constraint

未备份的 SymInt#

为了解析控制流,我们检查符号整数的提示(即实际值)来确定要转到哪个分支。但是,在某些情况下,我们可能没有提示:所谓的未备份符号整数是在数据依赖操作(如 .nonzero().item())中出现的大小变量。对这些符号整数执行控制流是非法的,因此我们必须在这些操作上进行图中断。

如果 naively 实现,这过于严格:大多数 PyTorch 程序在尝试对未备份的符号整数执行任何操作时都会立即失败。以下是对使其真正工作的最重要的增强功能:

  • 在张量创建时,PyTorch 会预先计算有关张量的许多数据;例如,如果您使用 empty_strided 创建张量,我们会主动排序跨步并确定张量是否不重叠且密集。排序会产生大量 guard。但是,更常见的是直接使用像 empty 这样的更高级 API 来生成张量,后者保证会生成非重叠且密集的张量。我们修改了 PyTorch 以避免不必要地重新计算这些属性。

  • 即使需要非平凡的计算,有时某个属性根本不会被查询。将这些预先计算的属性设为惰性使我们能够避免在未备份的符号整数上设置 guard,除非确实需要。

  • 整数张量中的数据通常不确定是否为非负数。但是,我们提供了一个 API constrain_range,用户可以通过它指定大小的上下界由已知限制。

与动态 API 类似,存在相应的未备份 API:即您可以使用 mark_unbacked 而不是 mark_dynamic,并使用 TORCH_COMPILE_UNBACKED_SOURCES 而不是 TORCH_COMPILE_DYNAMIC_SOURCES 来告诉编译器将输入标记为未备份。

在 PT2 的未来版本(PT2.1 之后)中,我们将扩展我们的推理系统,以根据用法推断未备份的符号整数是 size-like 的。例如,如果您将 .item() 调用结果传递给像 torch.empty 这样的工厂函数,我们将自动推断结果是 size(因为如果不是,它将失败。)此假设将在运行时得到验证,如果未满足,将引发错误。