评价此页

带描述符的联合#

创建于:2025 年 8 月 11 日 | 最后更新于:2025 年 8 月 11 日

带描述符的联合 (Joint with descriptors) 是一个实验性 API,用于导出支持 `torch.compile` 所有功能的、最通用的追踪联合图 (traced joint graph),并且在处理后可以转换回可微分的可调用对象 (differentiable callable),以正常方式执行。例如,它用于实现 autoparallel,这是一个接受模型并重新划分输入和参数以使其成为分布式 SPMD 程序的系统。

torch._functorch.aot_autograd.aot_export_joint_with_descriptors(stack, mod, args, kwargs=None, *, decompositions=None, keep_inference_input_mutations=False, ignore_shape_env=False, fw_compiler=<function boxed_nop_preserve_node_meta>, bw_compiler=<function boxed_nop_preserve_node_meta>)[source]#

此 API 捕获 `nn.Module` 的联合图。然而,与 `aot_export_joint_simple` 或 `aot_export_module(trace_joint=True)` 不同,生成的联合图的调用约定不遵循固定的位置模式;例如,您不能依赖于追踪联合图的第二个参数对应于您追踪的模块的第二个参数。但是,追踪图的输入和输出是用 **描述符 (descriptors)** 进行模式化的,这些描述符标注在占位符和返回的 FX 节点上的 `meta['desc']` 中,您可以使用它们来确定参数的含义。

与 `aot_export_joint_simple` 相比,使用此导出的主要好处是,我们拥有 `torch.compile` 支持的所有情况(通过 `aot_module_simplified`)的功能对等性,包括处理更复杂的情况,例如多个可微分输出、必须在图外部处理的输入变异、张量子类等。

您可以使用带有描述符的联合图做什么?其主要用例(autoparallel)涉及获取联合图,对其进行优化,然后将其转换回可调用对象,以便稍后可以进行 `torch.compile`。由于两个原因,这不能作为传统的 `torch.compile` 联合图传递完成:

  1. 参数的切分 (sharding) 必须在参数初始化/检查点加载之前决定,这远早于 `torch.compile` 通常运行的时间。

  2. 我们需要改变参数的含义(例如,我们可能会用分片版本替换复制参数,从而改变其输入大小)。`torch.compile` 通常是语义保留的,不允许更改输入的含义。

一些描述符可能相当奇异,因此我们建议仔细考虑是否存在一个安全的后备方案可以应用于您不理解的描述符。例如,您应该有一些方法来处理在最终 FX 图输入中找不到完全相同的特定输入的情况。

注意:使用此 API 时,您必须创建并进入 `ExitStack` 上下文管理器,并将其传递给此函数。如果您调用 `compile` 函数来完成编译,则此上下文管理器必须保持活动状态。(TODO:我们可能会放宽此要求,让 AOTAutograd 能够跟踪如何稍后重建所有上下文管理器。)

注意:您不必在第二阶段执行 /完整的/ 编译;相反,您可以不指定前向/后向编译器,在这种情况下,分区后的 FX 图将直接运行。整体 `autograd.Function` 可以保留在图中,以便您可以在(可能更大)已编译区域的上下文中稍后重新处理它。

注意:这些 API **不** 命中缓存,因为我们只缓存最终的编译结果,而不缓存中间导出结果。

注意:如果传入的 `nn.Module` 具有参数和缓冲区,我们将生成额外的隐式参数/缓冲区参数,并为其分配 `ParamAOTInput` 和 `BufferAOTInput` 描述符。但是,如果您从 Dynamo 等机制生成输入 `nn.Module`,则不会得到这些描述符(因为 Dynamo 已经处理了将参数/缓冲区提升为参数!)。在这种情况下,有必要分析输入的 `Sources` 以确定输入是否是参数及其 FQN。

返回类型

JointWithDescriptors

torch._functorch.aot_autograd.aot_compile_joint_with_descriptors(jd)[source]#

与 `aot_export_joint_with_descriptors` 配套的函数,它将联合图编译成一个遵循标准调用约定的可调用函数。`params_flat` 都是参数。

注意:我们 **不** 实例化模块;这为您提供了子类化并自定义其行为的灵活性,而无需担心 FQN 重新绑定。

TODO:考虑我们是否应默认允许在图中 (allow_in_graph) 返回结果。

返回类型

callable

描述符#

class torch._functorch._aot_autograd.descriptors.AOTInput[source]#

描述来自 AOTAutograd 生成的 FX 图的输入的来源

is_buffer()[source]#

如果此输入是缓冲区或派生自缓冲区(例如,子类属性),则为 True

返回类型

布尔值

is_param()[source]#

如果此输入是参数或派生自参数(例如,子类属性),则为 True

返回类型

布尔值

is_tangent()[source]#

如果此输入是切线 (tangent) 或派生自切线(例如,子类属性),则为 True

返回类型

布尔值

class torch._functorch._aot_autograd.descriptors.AOTOutput[source]#

描述 AOTAutograd 生成的 FX 图的输出最终将如何打包到最终输出中

is_grad()[source]#

如果此输出是梯度或派生自梯度(例如,子类属性),则为 True

返回类型

布尔值

class torch._functorch._aot_autograd.descriptors.BackwardTokenAOTInput(idx)[source]#

用于反向传播的副作用操作的世界令牌 (world token)

class torch._functorch._aot_autograd.descriptors.BackwardTokenAOTOutput(idx)[source]#

副作用调用的世界令牌输出,返回以便我们不会对其进行 DCE(死代码消除),仅用于反向传播

class torch._functorch._aot_autograd.descriptors.BufferAOTInput(target)[source]#

输入是缓冲区,其 FQN 为 target

class torch._functorch._aot_autograd.descriptors.DummyAOTInput(idx)[source]#

在某些情况下,我们希望调用一个期望 `AOTInput` 的函数,但我们实际上并不关心该逻辑(最典型的是,因为某些代码同时用于编译时和运行时;在此情况下不需要 `AOTInput` 处理)。在这种情况下传入一个 dummy;但最好是有一个根本没有这个的函数版本。

class torch._functorch._aot_autograd.descriptors.DummyAOTOutput(idx)[source]#

在您实际上不关心描述符传播的情况下,请勿在正常情况下使用。

class torch._functorch._aot_autograd.descriptors.GradAOTOutput(grad_of)[source]#

一个输出,表示在联合图中为可微分输入计算出的梯度

class torch._functorch._aot_autograd.descriptors.InputMutationAOTOutput(mutated_input)[source]#

输入的变异值,返回以便我们能够适当地传播自动微分。

class torch._functorch._aot_autograd.descriptors.IntermediateBaseAOTOutput(base_of)[source]#

多个别名(aliasing)输出的中间基。我们只报告一个贡献给该基的输出

class torch._functorch._aot_autograd.descriptors.ParamAOTInput(target)[source]#

输入是参数,其 FQN 为 target

class torch._functorch._aot_autograd.descriptors.PhiloxBackwardBaseOffsetAOTInput[source]#

功能化的 Philox RNG 调用的偏移量,专用于后向图。

class torch._functorch._aot_autograd.descriptors.PhiloxBackwardSeedAOTInput[source]#

功能化的 Philox RNG 调用的种子,专用于后向图。

class torch._functorch._aot_autograd.descriptors.PhiloxForwardBaseOffsetAOTInput[source]#

功能化的 Philox RNG 调用的偏移量,专用于前向图。

class torch._functorch._aot_autograd.descriptors.PhiloxForwardSeedAOTInput[source]#

功能化的 Philox RNG 调用的种子,专用于前向图。

class torch._functorch._aot_autograd.descriptors.PhiloxUpdatedBackwardOffsetAOTOutput[source]#

功能化 RNG 调用的最终偏移量,仅用于后向传播

class torch._functorch._aot_autograd.descriptors.PhiloxUpdatedForwardOffsetAOTOutput[source]#

功能化 RNG 调用的最终偏移量,仅用于前向传播

class torch._functorch._aot_autograd.descriptors.PlainAOTInput(idx)[source]#

输入是普通输入,对应于特定的位置索引。

注意,`AOTInput` 始终相对于具有 **扁平** 调用约定的函数(例如 `aot_module_simplified` 接受的)。有一些 AOTAutograd API 会扁平化 pytrees,我们不记录扁平化中的 PyTree 键路径(但我们应该能够!)。

class torch._functorch._aot_autograd.descriptors.PlainAOTOutput(idx)[source]#

输出元组位置 `idx` 处的普通张量输出

class torch._functorch._aot_autograd.descriptors.SavedForBackwardsAOTOutput(idx: int)[source]#
class torch._functorch._aot_autograd.descriptors.SubclassGetAttrAOTInput(base, attr)[source]#

子类输入在进入 FX 图之前会解包成其组成部分。这告诉您此输入对应于子类(原始子类参数)的哪个特定属性。

class torch._functorch._aot_autograd.descriptors.SubclassGetAttrAOTOutput(base, attr)[source]#

此输出将被打包到此位置的子类中

class torch._functorch._aot_autograd.descriptors.SubclassSizeAOTInput(base, idx)[source]#

这个特定的外部大小 SymInt 输入(在维度 idx 处)来自哪个子类。

class torch._functorch._aot_autograd.descriptors.SubclassSizeAOTOutput(base, idx)[source]#

此输出大小将被打包到此位置的子类中

class torch._functorch._aot_autograd.descriptors.SubclassStrideAOTInput(base, idx)[source]#

这个特定的外部步幅 SymInt 输入(在维度 idx 处)来自哪个子类。

class torch._functorch._aot_autograd.descriptors.SubclassStrideAOTOutput(base, idx)[source]#

此输出步幅将被打包到此位置的子类中

class torch._functorch._aot_autograd.descriptors.SyntheticBaseAOTInput(base_of)[source]#

这与 `ViewBaseAOTInput` 类似,但当没有视图是可微分的时,我们会发生这种情况,因此我们无法获取真正的原始视图,而是为了自动微分而构造了一个合成视图。

class torch._functorch._aot_autograd.descriptors.ViewBaseAOTInput(base_of)[source]#

当多个可微分输入是同一输入的视图时,AOTAutograd 会将这些视图替换为单个表示基的输入。如果您不希望这样,可以在将视图示例输入传递给 AOTAutograd 之前克隆它们。

TODO:原则上,我们可以报告所有贡献给此基的输入。

FX 工具#

此模块包含用于处理 AOTAutograd 生成的带描述符的联合 FX 图的实用函数。它们**不**适用于通用 FX 图。另请参阅 torch._functorch.aot_autograd.aot_export_joint_with_descriptors()。我们还建议阅读 :mod:torch._functorch._aot_autograd.descriptors`。

torch._functorch._aot_autograd.fx_utils.get_all_input_and_grad_nodes(g)[source]#

给定一个带描述符的联合图(占位符和输出上的 `meta['desc']`),返回每个输入及其对应的梯度输出节点(如果存在)。这些元组存储在一个字典中,该字典由描述输入的 `AOTInput` 描述符索引。

注意:返回 **所有** 前向张量输入,包括不可微分输入(这些输入只有一个 `None` 梯度),因此安全地使用此函数来对所有输入执行操作。(非张量输入,如符号整数、令牌或 RNG 状态,**不** 被此函数遍历。)

参数

g (Graph) – 带描述符的 FX 联合图

返回

一个字典,将每个 `DifferentiableAOTInput` 描述符映射到一个元组,该元组包含: - 输入节点本身 - 梯度(输出)节点(如果存在),否则为 `None`

引发
  • RuntimeError – 如果联合图包含子类张量输入/输出;此

  • API 不支持,因为当涉及子类时,输入和梯度之间不一定存在一对一的对应关系

  • 当涉及子类时。

返回类型

dict[torch._functorch._aot_autograd.descriptors.DifferentiableAOTInput, tuple[torch.fx.node.Node, Optional[torch.fx.node.Node]]]

torch._functorch._aot_autograd.fx_utils.get_all_output_and_tangent_nodes(g)[source]#

从联合图中获取所有输出节点及其对应的切线节点。

与 `get_all_input_and_grad_nodes` 类似,但返回输出节点与其切线节点配对(如果存在)。此函数遍历图以查找所有可微分输出,并将它们与其在正向模式自动微分中使用的相应切线输入进行匹配。

注意:返回 **所有** 前向张量输出,包括不可微分输出,因此您可以使用此函数来对所有输出执行操作。

参数

g (Graph) – 带描述符的 FX 联合图

返回

一个字典,将每个 `DifferentiableAOTOutput` 描述符映射到一个元组,该元组包含: - 输出节点本身 - 切线(输入)节点(如果存在),否则为 `None`

引发
  • RuntimeError – 如果联合图包含子类张量输入/输出;此

  • API 不支持,因为当涉及子类时,输入和梯度之间不一定存在一对一的对应关系

  • 当涉及子类时,输出和切线之间不存在一一对应关系。

返回类型

dict[torch._functorch._aot_autograd.descriptors.DifferentiableAOTOutput, tuple[torch.fx.node.Node, Optional[torch.fx.node.Node]]]

torch._functorch._aot_autograd.fx_utils.get_buffer_nodes(graph)[source]#

将图中的所有缓冲区节点获取为一个列表。

您可以依赖此函数提供您需要馈入联合图(在参数之后)的正确缓冲区顺序。

参数

graph (Graph) – 带描述符的 FX 联合图

返回

表示图中所有缓冲区的 FX 节点列表。

引发
  • RuntimeError – 如果遇到子类张量(尚不支持),因为

  • 不清楚您是否想要子类的每个单独的组成部分

  • 还是希望将它们以某种方式分组。

返回类型

list[torch.fx.node.Node]

torch._functorch._aot_autograd.fx_utils.get_named_buffer_nodes(graph)[source]#

按完全限定名称映射缓冲区节点。

此函数遍历图以查找所有缓冲区输入节点,并返回一个字典,其中键是缓冲区名称 (FQN),值是相应的 FX 节点。

参数

graph (Graph) – 带描述符的 FX 联合图

返回

将缓冲区名称 (str) 映射到其相应 FX 节点的字典。

引发
  • RuntimeError – 如果遇到子类张量(尚不支持),因为

  • 对于子类,FQN 不一定映射到一个普通张量。

返回类型

dict[str, torch.fx.node.Node]

torch._functorch._aot_autograd.fx_utils.get_named_param_nodes(graph)[source]#

按完全限定名称映射参数节点。

此函数遍历图以查找所有参数输入节点,并返回一个字典,其中键是参数名称 (FQN),值是相应的 FX 节点。

参数

graph (Graph) – 带描述符的 FX 联合图

返回

将参数名称 (str) 映射到其相应 FX 节点的字典。

引发
  • RuntimeError – 如果遇到子类张量(尚不支持),因为

  • 对于子类,FQN 不一定映射到一个普通张量。

返回类型

dict[str, torch.fx.node.Node]

torch._functorch._aot_autograd.fx_utils.get_param_and_grad_nodes(graph)[source]#

从联合图中获取参数节点及其对应的梯度节点。

参数

graph (Graph) – 带描述符的 FX 联合图

返回

  • 参数输入节点

  • 梯度(输出)节点(如果存在),否则为 `None`

返回类型

一个字典,将每个 `ParamAOTInput` 描述符映射到一个元组,该元组包含

torch._functorch._aot_autograd.fx_utils.get_param_nodes(graph)[source]#

将图中的所有参数节点获取为一个列表。

您可以依赖此函数提供您需要馈入联合图(在参数列表的开头,在缓冲区之前)的正确参数顺序。

参数

graph (Graph) – 带描述符的 FX 联合图

返回

表示图中所有参数的 FX 节点列表。

引发
  • RuntimeError – 如果遇到子类张量(尚不支持),因为

  • 不清楚您是否想要子类的每个单独的组成部分

  • 还是希望将它们以某种方式分组。

返回类型

list[torch.fx.node.Node]

torch._functorch._aot_autograd.fx_utils.get_plain_input_and_grad_nodes(graph)[source]#

从联合图中获取普通输入节点及其对应的梯度节点。

参数

graph (Graph) – 带描述符的 FX 联合图

返回

  • 普通输入节点

  • 梯度(输出)节点(如果存在),否则为 `None`

返回类型

一个字典,将每个 `PlainAOTInput` 描述符映射到一个元组,该元组包含

torch._functorch._aot_autograd.fx_utils.get_plain_output_and_tangent_nodes(graph)[source]#

从联合图中获取普通输出节点及其对应的切线节点。

参数

graph (Graph) – 带描述符的 FX 联合图

返回

  • 普通输出节点

  • 切线(输入)节点(如果存在),否则为 `None`

返回类型

一个字典,将每个 `PlainAOTOutput` 描述符映射到一个元组,该元组包含