torch.export API 参考#
创建日期:2025年7月17日 | 最后更新日期:2025年12月3日
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=False, preserve_module_call_signature=(), prefer_deferred_runtime_asserts_over_guards=False)[source]#
export()接收任意 nn.Module 及示例输入,并以提前(AOT, Ahead-of-Time)方式生成一个仅代表函数张量计算的跟踪图(traced graph),该图随后可用于不同输入执行或序列化。此跟踪图 (1) 生成函数化 ATen 算子集(以及任何用户指定的自定义算子)中的规范化算子;(2) 消除了所有 Python 控制流和数据结构(特定例外除外);(3) 记录了为了确保这种规范化和控制流消除对未来输入依然有效所需的形状约束集合。可靠性保证
在跟踪过程中,
export()会记录用户程序和底层 PyTorch 算子内核所做的形状相关假设。输出的ExportedProgram仅在这些假设成立时才被认为是有效的。跟踪过程会对输入张量的形状(而非数值)做出假设。这些假设必须在图捕获时进行验证,
export()才能成功。具体而言:对输入张量静态形状的假设会自动验证,无需额外操作。
对输入张量动态形状的假设需要明确指定:使用
Dim()API 构建动态维度,并通过dynamic_shapes参数将其与示例输入关联起来。
如果任何假设无法验证,将抛出致命错误。发生这种情况时,错误消息将包含验证假设所需的建议修正方案。例如,
export()可能会建议对动态维度dim0_x的定义进行以下修复,该维度出现在输入x的形状中,且之前定义为Dim("dim0_x")。dim = Dim("dim0_x", max=5)
此示例意味着生成的代码要求输入
x的第 0 维必须小于或等于 5 才能生效。您可以检查这些建议的动态维度定义修正方案,并将它们直接复制到您的代码中,而无需更改export()调用中的dynamic_shapes参数。- 参数:
mod (Module) – 我们将跟踪此模块的 forward 方法。
dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) –
一个可选参数,其类型应为:1) 从
f的参数名称到其动态形状规范的字典;2) 一个元组,按原始顺序为每个输入指定动态形状规范。如果您要为关键字参数指定动态性,则需要按照原始函数签名中定义的顺序传入它们。张量参数的动态形状可以指定为:(1) 从动态维度索引到
Dim()类型的字典(无需在此字典中包含静态维度索引,但若包含,应映射为 None);或 (2)Dim()类型或 None 的元组/列表,其中Dim()类型对应动态维度,静态维度用 None 表示。作为张量字典或元组/列表的参数,可以通过使用包含规范的映射或序列进行递归指定。strict (bool) – 当禁用(默认)时,导出函数将通过 Python 运行时跟踪程序,这本身不会验证图中固有的某些隐含假设。它仍会验证诸如形状安全等最关键的假设。当启用(设置
strict=True)时,导出函数将通过 TorchDynamo 跟踪程序,从而确保生成图的可靠性。TorchDynamo 对 Python 特性的覆盖范围有限,因此您可能会遇到更多错误。请注意,切换此参数不会改变生成的 IR 规范,无论此处传递什么值,模型都将以相同方式序列化。preserve_module_call_signature (tuple[str, ...]) – 一系列子模块路径,这些路径的原始调用约定将被保留作为元数据。调用
torch.export.unflatten时将使用此元数据,以恢复模块的原始调用约定。
- 返回:
一个包含跟踪后可调用对象的
ExportedProgram。- 返回类型:
可接受的输入/输出类型
可接受的输入(用于
args和kwargs)和输出类型包括:原始类型,即
torch.Tensor,int,float,bool和str。Dataclasses,但必须首先调用
register_dataclass()进行注册。包含上述所有类型的(嵌套)数据结构,包括
dict,list,tuple,namedtuple和OrderedDict。
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source]#
由
export()生成的程序包。它包含一个表示张量计算的torch.fx.Graph,一个包含所有已提升参数和缓冲区的张量值的 state_dict,以及各种元数据。您可以像调用
export()跟踪的原始可调用对象一样,使用相同的调用约定来调用 ExportedProgram。要对图执行转换,请使用
.module属性访问torch.fx.GraphModule。然后可以使用 FX 转换重写图。之后,您可以简单地再次使用export()来构建正确的 ExportedProgram。- property call_spec#
警告
此 API 处于实验阶段,不保证向后兼容。
- property constants#
警告
此 API 处于实验阶段,不保证向后兼容。
- property example_inputs#
警告
此 API 处于实验阶段,不保证向后兼容。
- property graph#
警告
此 API 处于实验阶段,不保证向后兼容。
- property graph_module#
警告
此 API 处于实验阶段,不保证向后兼容。
- property graph_signature#
警告
此 API 处于实验阶段,不保证向后兼容。
- module(check_guards=True)[source]#
返回一个包含所有参数/缓冲区的自包含 GraphModule。
当 check_guards=True(默认)时,将生成一个 _guards_fn 子模块,并在图中的占位符之后立即插入对该 _guards_fn 子模块的调用。此模块会检查输入的守卫(guards)。
当 check_guards=False 时,其中一部分检查将由图模块上的前向预钩(forward pre-hook)执行。不会生成 _guards_fn 子模块。
- 返回类型:
- property module_call_graph#
警告
此 API 处于实验阶段,不保证向后兼容。
- property range_constraints#
警告
此 API 处于实验阶段,不保证向后兼容。
- run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source]#
对导出的程序运行一组分解,并返回一个新的导出程序。默认情况下,我们将运行核心 ATen 分解以获得 Core ATen 算子集中的算子。
目前,我们不对联合图(joint graphs)进行分解。
- 参数:
decomp_table (dict[OperatorBase, Callable] | None) – 指定 Aten 算子分解行为的可选参数:(1) 若为 None,我们执行核心 Aten 分解;(2) 若为空字典,我们不对任何算子进行分解。
- 返回类型:
一些示例:
如果您不想分解任何内容:
ep = torch.export.export(model, ...) ep = ep.run_decompositions(decomp_table={})
如果您想获取除某些算子外的核心 aten 算子集,您可以执行以下操作:
ep = torch.export.export(model, ...) decomp_table = torch.export.default_decompositions() decomp_table[your_op] = your_custom_decomp ep = ep.run_decompositions(decomp_table=decomp_table)
- property state_dict#
警告
此 API 处于实验阶段,不保证向后兼容。
- property tensor_constants#
警告
此 API 处于实验阶段,不保证向后兼容。
- property verifiers#
警告
此 API 处于实验阶段,不保证向后兼容。
- class torch.export.dynamic_shapes.AdditionalInputs[source]#
根据额外输入推断 dynamic_shapes。
这对于部署工程师特别有用,因为他们一方面可能拥有丰富的测试或分析数据,能很好地了解模型的代表性输入,但另一方面可能并不了解模型,无法猜测哪些输入形状应该是动态的。
与原始输入不同的输入形状被视为动态的;反之,与原始形状相同的则被视为静态的。此外,我们还会验证额外输入对于导出的程序是否有效。这保证了使用它们而非原始数据进行跟踪时,也会生成相同的图。
示例
args0, kwargs0 = ... # example inputs for export # other representative inputs that the exported program will run on dynamic_shapes = torch.export.AdditionalInputs() dynamic_shapes.add(args1, kwargs1) ... dynamic_shapes.add(argsN, kwargsN) torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)
- dynamic_shapes(m, args, kwargs=None)[source]#
通过合并原始输入
args()和kwargs()与每个额外输入的 args 和 kwargs 的形状,推断出dynamic_shapes()的 pytree 结构。
- class torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source]#
Dim类允许用户在导出程序中指定动态性。通过用Dim标记一个维度,编译器会将该维度与包含动态范围的符号整数相关联。该 API 有两种使用方式:Dim 提示(即自动动态形状:
Dim.AUTO,Dim.DYNAMIC,Dim.STATIC),或命名 Dims(即Dim("name", min=1, max=2))。Dim 提示提供了最低的导出门槛,用户只需指定维度是动态的、静态的,还是留给编译器决定(
Dim.AUTO)。导出过程将自动推断有关最小/最大范围以及维度间关系的其余约束。示例
class Foo(nn.Module): def forward(self, x, y): assert x.shape[0] == 4 assert y.shape[0] >= 16 return x @ y x = torch.randn(4, 8) y = torch.randn(8, 16) dynamic_shapes = { "x": {0: Dim.AUTO, 1: Dim.AUTO}, "y": {0: Dim.AUTO, 1: Dim.AUTO}, } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
在此,如果我们将所有
Dim.AUTO的使用替换为Dim.DYNAMIC,导出将引发异常,因为模型将x.shape[0]约束为静态。维度之间更复杂的关系也可能被编译器编码为运行时断言节点,例如
(x.shape[0] + y.shape[1]) % 4 == 0,若运行时输入不满足此类约束,将引发异常。您还可以为 Dim 提示指定最小-最大界限,例如
Dim.AUTO(min=16, max=32),Dim.DYNAMIC(max=64),编译器会在范围内推断其余约束。如果有效范围完全超出了用户指定的范围,将引发异常。命名 Dims 提供了一种更严格的动态性指定方式,如果编译器推断出的约束与用户规范不符,则会引发异常。例如,导出之前的模型时,用户需要使用以下
dynamic_shapes参数:s0 = Dim("s0") s1 = Dim("s1", min=16) dynamic_shapes = { "x": {0: 4, 1: s0}, "y": {0: s0, 1: s1}, } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
命名 Dims 还允许指定维度之间的关系,最高支持一元线性关系。例如,下述内容表示一个维度是另一个维度的倍数加 4:
s0 = Dim("s0") s1 = 3 * s0 + 4
- class torch.export.dynamic_shapes.ShapesCollection[source]#
dynamic_shapes的构建器。用于为输入中出现的张量分配动态形状规范。这在
args()为嵌套输入结构时特别有用,此时对输入张量进行索引比在dynamic_shapes()规范中复制args()的结构更为简便。示例
args = {"x": tensor_x, "others": [tensor_y, tensor_z]} dim = torch.export.Dim(...) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[tensor_y] = {0: dim * 2} # This is equivalent to the following (now auto-generated): # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]} torch.export(..., args, dynamic_shapes=dynamic_shapes)
要指定整数的动态性,我们需要首先使用 _IntWrapper 包装整数,以便我们为每个整数都有一个“唯一识别标签”。
示例
args = {"x": tensor_x, "others": [int_x, int_y]} # Wrap all ints with _IntWrapper mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC # This is equivalent to the following (now auto-generated): # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]} torch.export(..., args, dynamic_shapes=dynamic_shapes)
- dynamic_shapes(m, args, kwargs=None)[source]#
根据
args()和kwargs()生成dynamic_shapes()pytree 结构。
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source]#
当使用
dynamic_shapes()进行导出时,如果规范与从模型追踪(tracing)中推断出的约束不匹配,导出可能会失败并抛出 ConstraintViolation 错误。错误消息可能会提供建议的修复方案——即可以对dynamic_shapes()所做的修改,以实现成功导出。示例 ConstraintViolation 错误消息
Suggested fixes: dim = Dim('dim', min=3, max=6) # this just refines the dim's range dim = 4 # this specializes to a constant dy = dx + 1 # dy was specified as an independent dim, but is actually tied to dx with this relation
这是一个辅助函数,它接收 ConstraintViolation 错误消息和原始的
dynamic_shapes()规范,并返回一个新的dynamic_shapes()规范,其中包含了建议的修复方案。使用示例
try: ep = export(mod, args, dynamic_shapes=dynamic_shapes) except torch._dynamo.exc.UserError as exc: new_shapes = refine_dynamic_shapes_from_suggested_fixes( exc.msg, dynamic_shapes ) ep = export(mod, args, dynamic_shapes=new_shapes)
- torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)[source]#
警告
由于处于积极开发阶段,保存的文件可能无法在较新版本的 PyTorch 中使用。
将
ExportedProgram保存到类文件对象(file-like object)中。随后可以使用 Python APItorch.export.load进行加载。- 参数:
ep (ExportedProgram) – 要保存的导出程序。
f (str | os.PathLike[str] | IO[bytes]) – 实现 write 和 flush 方法的类文件对象,或包含文件名的字符串。
extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,将作为 f 的一部分存储。
opset_version (Optional[Dict[str, int]]) – opset 名称到该 opset 版本的映射。
pickle_protocol (int) – 可指定以覆盖默认协议。
示例
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, "exported_program.pt2") # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {"foo.txt": b"bar".decode("utf-8")} torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]#
警告
由于处于积极开发阶段,保存的文件可能无法在较新版本的 PyTorch 中使用。
警告
torch.export.load()在底层使用 pickle 来加载模型。切勿从不受信任的来源加载数据。加载之前使用
torch.export.save保存的ExportedProgram。- 参数:
- 返回:
一个
ExportedProgram对象。- 返回类型:
示例
import torch import io # Load ExportedProgram from file ep = torch.export.load("exported_program.pt2") # Load ExportedProgram from io.BytesIO object with open("exported_program.pt2", "rb") as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {"foo.txt": ""} # values will be replaced with data ep = torch.export.load("exported_program.pt2", extra_files=extra_files) print(extra_files["foo.txt"]) print(ep(torch.randn(5)))
- torch.export.pt2_archive._package.package_pt2(f, *, exported_programs=None, aoti_files=None, extra_files=None, opset_version=None, pickle_protocol=2, executorch_files=None)[source]#
将工件保存为 PT2Archive 格式。该工件随后可以使用
load_pt2进行加载。- 参数:
f (str | os.PathLike[str] | IO[bytes]) – 一个类文件对象(必须实现 write 和 flush)或包含文件名的字符串。
exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]) – 要保存的导出程序,或从模型名称到要保存的导出程序的字典映射。导出的程序将保存在 models/*.json 下。如果仅指定了一个 ExportedProgram,它将被自动命名为“model”。
aoti_files (Union[list[str], dict[str, list[str]]]) – 由 AOTInductor 通过
torch._inductor.aot_compile(..., {"aot_inductor.package": True})生成的文件列表,或者从模型名称到其 AOTInductor 生成文件的字典映射。如果仅指定了一组文件,它们将被自动命名为“model”。extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,将作为 pt2 的一部分进行存储。
opset_version (Optional[Dict[str, int]]) – opset 名称到该 opset 版本的映射。
pickle_protocol (int) – 可指定以覆盖默认协议。
executorch_files (Optional[dict[str, bytes]]) – 要保存的可选 ExecuTorch 工件。
- 返回类型:
- torch.export.pt2_archive._package.load_pt2(f, *, expected_opset_version=None, run_single_threaded=False, num_runners=1, device_index=-1, load_weights_from_disk=False)[source]#
加载之前使用
package_pt2保存的所有工件。- 参数:
f (str | os.PathLike[str] | IO[bytes]) – 一个类文件对象(必须实现 write 和 flush)或包含文件名的字符串。
expected_opset_version (Optional[Dict[str, int]]) – opset 名称到预期 opset 版本的映射。
num_runners (int) – 用于加载 AOTInductor 工件的运行器数量。
run_single_threaded (bool) – 模型是否应在没有线程同步逻辑的情况下运行。这对于避免与 CUDAGraphs 冲突很有用。
device_index (int) – 要将 PT2 包加载到的设备索引。默认情况下使用 device_index=-1,在使用 CUDA 时对应于 cuda 设备。例如,传递 device_index=1 会将包加载到 cuda:1。
- 返回:
包含 PT2 中所有对象的
PT2ArchiveContents对象。- 返回类型:
PT2ArchiveContents
- torch.export.draft_export(mod, args, kwargs=None, *, dynamic_shapes=None, preserve_module_call_signature=(), strict=False, prefer_deferred_runtime_asserts_over_guards=False)[source]#
torch.export.export 的一个版本,旨在始终如一地生成 ExportedProgram(即使存在潜在的稳健性问题),并生成一份列出所发现问题的报告。
- 返回类型:
- class torch.export.unflatten.FlatArgsAdapter[source]#
使用
input_spec适配输入参数,以与target_spec对齐。
- class torch.export.unflatten.InterpreterModule(graph, ty=None)[source]#
一个使用 torch.fx.Interpreter 进行执行的模块,而不是使用 GraphModule 通常使用的代码生成。这提供了更好的堆栈跟踪信息,并使执行调试变得更容易。
- class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source]#
一个模块,携带一系列 InterpreterModules,对应于该模块的一系列调用。对模块的每次调用都会分发到下一个 InterpreterModule,并在最后一次调用后重新开始循环。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]#
对 ExportedProgram 进行非扁平化(unflatten)处理,生成一个具有与原始 eager 模块相同模块层级的模块。如果你尝试将
torch.export与其他期望模块层级而不是torch.export通常产生的扁平图的系统一起使用时,这可能会很有用。注意
非扁平化模块的 args/kwargs 不一定与 eager 模块匹配,因此进行模块交换(例如
self.submod = new_mod)不一定会生效。如果你需要交换一个模块,你需要设置torch.export.export()的preserve_module_call_signature参数。- 参数:
module (ExportedProgram) – 要非扁平化的 ExportedProgram。
flat_args_adapter (Optional[FlatArgsAdapter]) – 如果输入 TreeSpec 与导出模块的不匹配,则适配扁平化参数。
- 返回:
UnflattenedModule的一个实例,它具有与导出前的原始 eager 模块相同的模块层级。- 返回类型:
UnflattenedModule
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[source]#
将 dataclass 注册为
torch.export.export()的有效输入/输出类型。- 参数:
示例
import torch from dataclasses import dataclass @dataclass class InputDataClass: feature: torch.Tensor bias: int @dataclass class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) class Mod(torch.nn.Module): def forward(self, x: InputDataClass) -> OutputDataClass: res = x.feature + x.bias return OutputDataClass(res=res) ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),)) print(ep)
- class torch.export.decomp_utils.CustomDecompTable[source]#
这是一个专门用于在导出中处理 decomp_table 的自定义字典。我们需要它的原因是,在新世界中,你只能从 decomp 表中删除一个算子以保留它。这对于自定义算子是有问题的,因为我们不知道自定义算子何时真正加载到分发器(dispatcher)。因此,我们需要记录自定义算子操作,直到我们真正需要将其具体化(即当我们运行分解过程时)。
- 我们维持的不变量是:
所有 aten 分解都在初始化时加载。
当用户读取表时,我们会具体化所有算子,以提高分发器拾取自定义算子的可能性。
如果是写操作,我们不一定具体化。
我们在导出期间(即调用 run_decompositions() 之前)进行最后一次加载。
- torch.export.passes.move_to_device_pass(ep, location)[source]#
将导出的程序移动到给定的设备。
- 参数:
ep (ExportedProgram) – 要移动的导出程序。
location (Union[torch.device, str, Dict[str, str]]) – 要移动到的设备。如果是字符串,则解释为设备名称。如果是字典,则解释为从现有设备到目标设备的映射。
- 返回:
移动后的导出程序。
- 返回类型:
- class torch.export.pt2_archive.PT2ArchiveReader(archive_path_or_buffer)#
用于读取 PT2 归档的上下文管理器。
- class torch.export.pt2_archive.PT2ArchiveWriter(archive_path_or_buffer)#
用于写入 PT2 归档的上下文管理器。
- class torch.export.exported_program.ModuleCallEntry(fqn: str, signature: torch.export.exported_program.ModuleCallSignature | None = None)[source]#
- class torch.export.exported_program.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: list[str] | None = None)[source]#
- torch.export.exported_program.default_decompositions()[source]#
这是默认的分解表,其中包含所有 ATEN 算子到核心 aten 算子集的分解。请将此 API 与
run_decompositions()一起使用。- 返回类型:
- class torch.export.custom_obj.ScriptObjectMeta(constant_name, class_fqn)[source]#
存储在表示 ScriptObject 的节点上的元数据。
- class torch.export.graph_signature.ConstantArgument(name: str, value: int | float | bool | str | None)[source]#
- class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: torch._library.fake_class_registry.FakeScriptObject | None = None)[source]#
- class torch.export.graph_signature.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)[source]#
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source]#
ExportGraphSignature对 Export Graph 的输入/输出签名进行建模,这是一个具有更强不变量保证的 fx.Graph。Export Graph 是函数式的,不会通过
getattr节点访问图内的“状态”(如参数或缓冲区)。相反,export()保证参数、缓冲区和常量张量会作为输入从图中提取出来。同样,缓冲区的任何突变也不会包含在图中,而是将突变缓冲区的更新值建模为 Export Graph 的额外输出。所有输入和输出的顺序为
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如:如果导出以下模块
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer("my_buffer1", torch.tensor(3.0)) self.register_buffer("my_buffer2", torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = ( x1 + self.my_parameter ) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output mod = CustomModule() ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
结果图是非函数式的
graph(): %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] %x1 : [num_users=1] = placeholder[target=x1] %x2 : [num_users=1] = placeholder[target=x2] %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) return (add_1,)
非函数式图的所得 ExportGraphSignature 将为
# inputs p_my_parameter: PARAMETER target='my_parameter' b_my_buffer1: BUFFER target='my_buffer1' persistent=True b_my_buffer2: BUFFER target='my_buffer2' persistent=True x1: USER_INPUT x2: USER_INPUT # outputs add_1: USER_OUTPUT
要获得函数式图,可以使用
run_decompositions()mod = CustomModule() ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0))) ep = ep.run_decompositions()
结果图是函数式的
graph(): %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] %x1 : [num_users=1] = placeholder[target=x1] %x2 : [num_users=1] = placeholder[target=x2] %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) return (add_2, add_1)
函数式图的所得 ExportGraphSignature 将为
# inputs p_my_parameter: PARAMETER target='my_parameter' b_my_buffer1: BUFFER target='my_buffer1' persistent=True b_my_buffer2: BUFFER target='my_buffer2' persistent=True x1: USER_INPUT x2: USER_INPUT # outputs add_2: BUFFER_MUTATION target='my_buffer2' add_1: USER_OUTPUT
- property backward_signature: ExportBackwardSignature | None#
- property buffers: Collection[str]#
- property input_tokens: Collection[str]#
- property lifted_custom_objs: Collection[str]#
- property lifted_tensor_constants: Collection[str]#
- property non_persistent_buffers: Collection[str]#
- output_specs: list[OutputSpec]#
- property output_tokens: Collection[str]#
- property parameters: Collection[str]#
- class torch.export.graph_signature.InputKind(value)[source]#
一个枚举类型。
- BUFFER = 3#
- CONSTANT_TENSOR = 4#
- CUSTOM_OBJ = 5#
- PARAMETER = 2#
- TOKEN = 6#
- USER_INPUT = 1#
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: torch.export.graph_signature.TensorArgument | torch.export.graph_signature.SymIntArgument | torch.export.graph_signature.SymFloatArgument | torch.export.graph_signature.SymBoolArgument | torch.export.graph_signature.ConstantArgument | torch.export.graph_signature.CustomObjArgument | torch.export.graph_signature.TokenArgument, target: str | None, persistent: bool | None = None)[source]#
- class torch.export.graph_signature.OutputKind(value)[source]#
一个枚举类型。
- BUFFER_MUTATION = 3#
- GRADIENT_TO_PARAMETER = 5#
- GRADIENT_TO_USER_INPUT = 6#
- LOSS_OUTPUT = 2#
- PARAMETER_MUTATION = 4#
- TOKEN = 8#
- USER_INPUT_MUTATION = 7#
- USER_OUTPUT = 1#
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: torch.export.graph_signature.TensorArgument | torch.export.graph_signature.SymIntArgument | torch.export.graph_signature.SymFloatArgument | torch.export.graph_signature.SymBoolArgument | torch.export.graph_signature.ConstantArgument | torch.export.graph_signature.CustomObjArgument | TokenArgument, target: str | None)[source]#
- arg: TensorArgument | SymIntArgument | SymFloatArgument | SymBoolArgument | ConstantArgument | CustomObjArgument | TokenArgument#
- kind: OutputKind#