动态形状核心概念#
创建日期: 2025年9月22日 | 最后更新日期: 2025年12月3日
本节描述了 PyTorch 中动态形状的核心概念。它旨在作为 PyTorch 编译器堆栈工程师以及任何希望了解动态形状内部工作原理的人员的参考。
符号整数#
符号整数 (Symints) 用于表示可以跨越范围的变量。例如
x = torch.randn(5, 5) # this tensor has a shape [5, 5]
torch._dynamo.decorators.mark_dynamic(x, 0)
x = torch.randn(5, 5) # this tensor has a shape [s0, 5]
y = torch.cat([x, x], dim=0) # this tensor has a shape [2*s0, 5]
但是,z = x * y 会引发错误,因为我们知道像乘法这样的逐点运算必须在大小相同的张量上进行操作,但我们静态地知道 s0 != 2 * s0。细心的读者可能会指出,当 s0 == 0 时,这并不成立,而此处不影响的原因在 零一特化问题 中有描述。
守卫#
在 torch.compile 中,守卫是一种用于确保编译代码图有效性的机制。默认情况下,当您使变量动态化时,它的范围可以是 [-inf, inf]。例如
def foo(x): return x / 2
This works for any dynamic x. But if your code is:
def foo(x)
if x > 5:
return x / 2
return x / 3
如果您调用 foo(6),它将返回 x / 2 并添加一个守卫 x > 5。稍后调用 foo(4) 将需要重新编译,因为守卫被破坏了。
运行时断言#
您可以使用运行时断言来提供提示,当您知道某些事实时,例如批大小小于 100
def foo(batch_size):
torch._check(batch_size < 100)
if batch_size < 100:
return do_something
return do_something_else()
“提示”值#
在 torch.compile 的上下文中,“提示”值是指在编译过程中已知的值,这些值有助于 JIT 编译器做出关于表达式的决策。提示值对于处理动态形状特别有用,因为它们提供了具体的指导编译的信息,而无需为不同的维度进行重新编译。
动态行为概述#
PyTorch 默认假定静态形状。当检测到大小更改时,它会尝试使用动态输入重新编译,但这可能会失败,如果存在条件分支或缺少对动态形状的支持。要诊断过度特化,您可以设置 TORCH_LOGS=dynamic 以查看“eval”条目,这些条目指示何时以及为什么添加守卫。
如果您预计某个维度将是动态的,可以使用 torch._dynamo.mark_dynamic(tensor, dim) 提前标记它,如果已知,则指定 min 和 max 值。使用 torch.compile(dynamic=False) 会禁用自动动态形状,从而导致为每个唯一大小重新编译。相反,torch.compile(dynamic=True) 旨在尽可能使用动态形状,这对于小型模型最有用,并且可能不适合大型模型,因为可能发生崩溃或性能问题。
您可以使用 TORCH_COMPILE_DYNAMIC_SOURCES 环境变量或 torch.compiler.config.dynamic_sources 来将特定源列入白名单,以标记为动态。这对于具有图中断的大型模型特别有用,因为您可以保持跨图中断的动态性,因为源名称保持一致。您还可以使用它来将整数标记为动态。格式是逗号分隔的源名称列表,例如 "L['x'], L['y']"。您还可以使用正则表达式,例如 "L\['x.*'\], L\['y.*'\]")。此白名单优先于其他标志,例如 dynamic=False force_nn_module_property_static_shapes 和 force_parameter_static_shapes。
有时,找到要标记为动态的正确输入可能很麻烦。如果您愿意承担第一批次的性能损失,我们还有另一种经济的选择,那就是 eager_then_compile 姿态,它可以为您推导出动态性。有关更多详细信息,请参阅 torch.compiler.set_stance()。
总体架构#
符号形状工作流
在 Dynamo 中编译框架时,我们会分配一个
ShapeEnv(附加到FakeTensorMode)来跟踪符号形状。我们根据策略决策在进入时为张量分配符号大小。
我们传播符号大小通过运算符,维护 FX IR 用于符号计算导出和 Sympy 表达式用于推理。
我们在 Dynamo 跟踪或 Inductor 优化期间基于条件添加守卫,由 Python 和 C++ 引起。
守卫可以简化符号变量。例如,断言
s0 == 4允许将所有出现的s0替换为4。在跟踪和优化之后,我们将所有守卫与编译的代码一起安装,仅当所有守卫评估为 true 时才确保可重用性。
内部 API 类层次结构#
Python 类#
SymInt/SymFloat/SymBool:用户可见的类,模拟它们的int/float/bool对应物。添加两个SymInts会产生一个新的SymInt,该SymInt会符号地跟踪整数加法。SymNode:内部结构(可通过symint.node访问),保存实际的符号跟踪信息。SymNode被类型擦除,使其方便地表示混合类型操作。ShapeEnv:每个编译上下文状态,跟踪到目前为止的所有自由符号和累积的守卫。每个SymNode记录其ShapeEnv(但不反过来;SymNodes仅在参与守卫时才使用)。
C++ 等效项#
c10::SymInt/SymFloat/SymBool:用户可见的类,模拟int/float/boolc10::SymNode/SymNodeImpl:类似于 PythonSymNode没有 C++
ShapeEnv:为了方便调试,整个符号推理装置都保留在 Python 中
在编写可使用 make_fx 跟踪的代码时,它必须处理流经它的 SymInt/SymFloat/SymBool。
值范围和约束#
符号变量维护值范围,用于指定可能值的集合。默认情况下
类似大小的未支持的
SymInts的值范围为[0, Inf]常规未支持的
SymInts的值范围为[-Inf, Inf]
当做出断言时(例如,torch._check(x == y)),系统
尝试使用等效表达式替换未支持的符号
根据断言细化值范围
记住始终为真的布尔表达式
重要文件
C++ SymInt API:
c10/core/SymInt.h,SymFloat.h,SymBool.hPython SymInt API:
torch/__init__.py(查找SymInt/SymFloat/SymBool)C++ 基础组件:
c10/core/SymNodeImpl.h,torch/csrc/utils/python_symnode.h,torch/csrc/jit/python/init.cppPython 基础设施:
torch/fx/experimental/symbolic_shapes.py其他重要文件:
torch/_subclasses/fake_tensor.py,torch/_meta_registrations.py,decomps,PrimTorch 引用