torch.export API 参考#
创建于: 2025年7月17日 | 最后更新于: 2025年7月17日
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=False, preserve_module_call_signature=(), prefer_deferred_runtime_asserts_over_guards=False)[source]#
export()接收任意 nn.Module 和示例输入,以提前(AOT)方式生成一个仅表示函数张量计算的跟踪图,之后可以使用不同的输入执行或序列化该图。跟踪图 (1) 在函数式 ATen 运算符集中生成规范化的运算符(以及用户指定的任何自定义运算符),(2) 消除了所有 Python 控制流和数据结构(某些例外情况除外),并且 (3) 记录了显示这种规范化和控制流消除对于未来输入是健全的形状约束。健全性保证
在跟踪期间,
export()会记录用户程序和底层 PyTorch 运算符内核所做的与形状相关的假设。只有当这些假设成立时,生成的ExportedProgram才被认为是有效的。跟踪会做出关于输入张量形状(而非值)的假设。为了使
export()成功,这些假设必须在图捕获时进行验证。具体来说:对输入张量静态形状的假设无需额外工作即可自动验证。
对输入张量动态形状的假设需要通过使用
Dim()API 来构造动态维度,并通过dynamic_shapes参数将其与示例输入关联来显式指定。
如果任何假设无法验证,将引发致命错误。发生这种情况时,错误消息将包含对验证假设所需的规范的建议修复。例如,
export()可能会为动态维度dim0_x的定义提出以下修复,该维度出现在输入x的形状中,该维度以前定义为Dim("dim0_x")。dim = Dim("dim0_x", max=5)
此示例意味着生成的代码要求输入
x的维度 0 小于或等于 5 才能有效。您可以检查动态维度定义的建议修复,然后将其逐字复制到您的代码中,而无需更改传递给export()调用的dynamic_shapes参数。- 参数
mod (Module) – 我们将跟踪此模块的 forward 方法。
dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]]) –
一个可选参数,其类型应为:1)一个字典,从
f的参数名称映射到其动态形状规范,2)一个元组,指定按原始顺序排列的每个输入的动态形状规范。如果您正在为关键字参数指定动态性,则需要按照函数原始签名中定义的顺序传递它们。张量参数的动态形状可以指定为:1)一个从动态维度索引到
Dim()类型的字典,其中不需要在此字典中包含静态维度索引,但当它们存在时,应将其映射到 None;或 2)一个Dim()类型或 None 的元组/列表,其中Dim()类型对应动态维度,静态维度由 None 表示。由字典或张量元组/列表组成的参数通过使用包含的规范的映射或序列来递归指定。strict (bool) – 当禁用(默认)时,export 函数将通过 Python 运行时跟踪程序,这本身不会验证图中的一些隐式假设。它仍然会验证大多数关键假设,例如形状安全性。当启用(通过设置
strict=True)时,export 函数将通过 TorchDynamo 跟踪程序,这将确保生成图的健全性。TorchDynamo 对 Python 特性的覆盖有限,因此您可能会遇到更多错误。请注意,切换此参数不会影响生成的 IR 规范,模型将以相同的方式序列化,无论此处传递什么值。preserve_module_call_signature (tuple[str, ...]) – 一个子模块路径列表,将保留其原始调用约定作为元数据。调用 torch.export.unflatten 时将使用元数据来保留模块的原始调用约定。
- 返回
一个包含跟踪的可调用对象的
ExportedProgram。- 返回类型
可接受的输入/输出类型
可接受的输入(用于
args和kwargs)和输出类型包括:基本类型,即
torch.Tensor、int、float、bool和str。数据类,但它们必须首先通过调用
register_dataclass()进行注册。(嵌套的) 数据结构,由
dict、list、tuple、namedtuple和OrderedDict组成,其中包含上述所有类型。
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source]#
来自
export()的程序包。它包含一个torch.fx.Graph,该图表示张量计算,一个包含所有提升的参数和缓冲区张量值的 state_dict,以及各种元数据。您可以使用与
export()跟踪的原始可调用对象相同的调用约定来调用 ExportedProgram。要对图执行转换,请使用 `.module` 属性访问
torch.fx.GraphModule。然后,您可以使用 FX 转换 来重写图。之后,您只需再次使用export()即可构建一个正确的 ExportedProgram。- property call_spec#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- property constants#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- property example_inputs#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- property graph#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- property graph_module#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- property graph_signature#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- module(check_guards=True)[source]#
返回一个自包含的 GraphModule,其中所有参数/缓冲区都已内联。
当 check_guards=True (默认) 时,将生成一个 _guards_fn 子模块,并在图中的占位符之后插入一个对 _guards_fn 子模块的调用。此模块检查输入的 guard。
当 check_guards=False 时,一部分检查将由图模块的 forward pre-hook 执行。不会生成 _guards_fn 子模块。
- 返回类型
- property module_call_graph#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- property range_constraints#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source]#
对导出的程序运行一组分解,并返回一个新的导出的程序。默认情况下,我们将运行核心 ATen 分解,以在 Core ATen Operator Set 中获得运算符。
目前,我们不分解联合图。
- 参数
decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – 一个可选参数,指定 Aten ops 的分解行为 (1) 如果为 None,我们分解为核心 aten 分解 (2) 如果为空,我们不分解任何运算符
- 返回类型
一些例子
如果您不想分解任何内容
ep = torch.export.export(model, ...) ep = ep.run_decompositions(decomp_table={})
如果您想获取核心 aten 运算符集,但排除某些运算符,您可以这样做:
ep = torch.export.export(model, ...) decomp_table = torch.export.default_decompositions() decomp_table[your_op] = your_custom_decomp ep = ep.run_decompositions(decomp_table=decomp_table)
- property state_dict#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- property tensor_constants#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- property verifiers#
警告
此 API 仍处于实验阶段,并且 *不* 向后兼容。
- class torch.export.dynamic_shapes.AdditionalInputs[source]#
根据附加输入推断 dynamic_shapes。
这对于部署工程师特别有用,他们一方面可能拥有充足的测试或分析数据,可以提供对模型代表性输入的良好认识,但另一方面,他们可能对模型了解不够,无法猜测哪些输入形状应该是动态的。
与原始输入不同的输入形状被视为动态;反之,与原始输入相同的形状被视为静态。此外,我们验证附加输入对于导出的程序是有效的。这保证了用它们代替原始输入进行跟踪会生成相同的图。
示例
args0, kwargs0 = ... # example inputs for export # other representative inputs that the exported program will run on dynamic_shapes = torch.export.AdditionalInputs() dynamic_shapes.add(args1, kwargs1) ... dynamic_shapes.add(argsN, kwargsN) torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)
- dynamic_shapes(m, args, kwargs=None)[source]#
通过合并原始输入
args()和kwargs()以及每个附加输入 args 和 kwargs 的形状,推断出dynamic_shapes()的 Pytree 结构。
- class torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source]#
Dim类允许用户在导出的程序中指定动态性。通过用Dim标记一个维度,编译器会将该维度与包含动态范围的符号整数关联起来。该 API 可以以两种方式使用:Dim 提示(即自动动态形状:
Dim.AUTO、Dim.DYNAMIC、Dim.STATIC)或命名 Dim(即Dim("name", min=1, max=2))。Dim 提示提供了导出能力的最低门槛,用户只需指定维度是动态的、静态的,还是由编译器决定(
Dim.AUTO)。导出过程将自动推断剩余的关于最小/最大范围以及维度之间关系的约束。示例
class Foo(nn.Module): def forward(self, x, y): assert x.shape[0] == 4 assert y.shape[0] >= 16 return x @ y x = torch.randn(4, 8) y = torch.randn(8, 16) dynamic_shapes = { "x": {0: Dim.AUTO, 1: Dim.AUTO}, "y": {0: Dim.AUTO, 1: Dim.AUTO}, } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
在这里,如果我们用
Dim.DYNAMIC替换所有Dim.AUTO的用法,导出将引发异常,因为模型已将x.shape[0]约束为静态。维度之间更复杂的关系也可能被编译器编码为运行时断言节点,例如
(x.shape[0] + y.shape[1]) % 4 == 0,如果运行时输入不满足这些约束,将引发该断言。您还可以为 Dim 提示指定最小-最大边界,例如
Dim.AUTO(min=16, max=32)、Dim.DYNAMIC(max=64),编译器将在这些范围内的剩余约束进行推断。如果有效范围完全超出用户指定的范围,将引发异常。命名 Dim 提供了一种更严格的方式来指定动态性,如果编译器推断出的约束与用户规范不匹配,则会引发异常。例如,导出之前的模型,用户将需要以下
dynamic_shapes参数。s0 = Dim("s0") s1 = Dim("s1", min=16) dynamic_shapes = { "x": {0: 4, 1: s0}, "y": {0: s0, 1: s1}, } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
命名 Dim 还允许指定维度之间的关系,最多为单变量线性关系。例如,以下表示一个维度是另一个维度的倍数加 4。
s0 = Dim("s0") s1 = 3 * s0 + 4
- class torch.export.dynamic_shapes.ShapesCollection[source]#
dynamic_shapes 的构建器。用于为输入中出现的张量分配动态形状规范。
这特别有用,当
args()是嵌套输入结构时,索引输入张量比在dynamic_shapes()规范中复制args()的结构更容易。示例
args = {"x": tensor_x, "others": [tensor_y, tensor_z]} dim = torch.export.Dim(...) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[tensor_y] = {0: dim * 2} # This is equivalent to the following (now auto-generated): # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]} torch.export(..., args, dynamic_shapes=dynamic_shapes)
要为整数指定动态性,我们需要先用 `_IntWrapper` 包装整数,这样我们就可以为每个整数拥有一个“唯一标识符”。
示例
args = {"x": tensor_x, "others": [int_x, int_y]} # Wrap all ints with _IntWrapper mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC # This is equivalent to the following (now auto-generated): # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]} torch.export(..., args, dynamic_shapes=dynamic_shapes)
- dynamic_shapes(m, args, kwargs=None)[source]#
根据
args()和kwargs()生成dynamic_shapes()Pytree 结构。
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source]#
当使用
dynamic_shapes()进行导出时,如果规范与从跟踪模型推断出的约束不匹配,导出可能会因 ConstraintViolation 错误而失败。错误消息可能会提供建议的修复 - 可以对dynamic_shapes()进行的更改,以成功导出。示例 ConstraintViolation 错误消息
Suggested fixes: dim = Dim('dim', min=3, max=6) # this just refines the dim's range dim = 4 # this specializes to a constant dy = dx + 1 # dy was specified as an independent dim, but is actually tied to dx with this relation
这是一个辅助函数,它接收 ConstraintViolation 错误消息和原始
dynamic_shapes()规范,并返回一个包含建议修复的新dynamic_shapes()规范。使用示例
try: ep = export(mod, args, dynamic_shapes=dynamic_shapes) except torch._dynamo.exc.UserError as exc: new_shapes = refine_dynamic_shapes_from_suggested_fixes( exc.msg, dynamic_shapes ) ep = export(mod, args, dynamic_shapes=new_shapes)
- torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)[source]#
警告
正在积极开发中,保存的文件可能无法在新版本的 PyTorch 中使用。
将
ExportedProgram保存到文件类对象。然后可以使用 Python APItorch.export.load加载它。- 参数
ep (ExportedProgram) – 要保存的导出的程序。
f (str | os.PathLike[str] | IO[bytes]) – 实现 write 和 flush 的文件对象,或包含文件名的字符串。
extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,将作为 f 的一部分存储。
opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到该 opset 版本的映射
pickle_protocol (int) – 可以指定以覆盖默认协议
示例
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, "exported_program.pt2") # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {"foo.txt": b"bar".decode("utf-8")} torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]#
警告
正在积极开发中,保存的文件可能无法在新版本的 PyTorch 中使用。
加载之前使用
torch.export.save保存的ExportedProgram。- 参数
- 返回
一个
ExportedProgram对象- 返回类型
示例
import torch import io # Load ExportedProgram from file ep = torch.export.load("exported_program.pt2") # Load ExportedProgram from io.BytesIO object with open("exported_program.pt2", "rb") as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {"foo.txt": ""} # values will be replaced with data ep = torch.export.load("exported_program.pt2", extra_files=extra_files) print(extra_files["foo.txt"]) print(ep(torch.randn(5)))
- torch.export.pt2_archive._package.package_pt2(f, *, exported_programs=None, aoti_files=None, extra_files=None, opset_version=None, pickle_protocol=2)[source]#
将工件保存为 PT2Archive 格式。该工件随后可以使用 `load_pt2` 加载。
- 参数
f (str | os.PathLike[str] | IO[bytes]) – 文件类对象(必须实现 write 和 flush)或包含文件名的字符串。
exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]) – 要保存的导出的程序,或者是一个将模型名称映射到导出的程序的字典。导出的程序将保存在 models/*.json 下。如果只指定了一个 ExportedProgram,它将自动命名为“model”。
aoti_files (Union[list[str], dict[str, list[str]]]) – 由 AOTInductor 通过
torch._inductor.aot_compile(..., {"aot_inductor.package": True})生成的文件列表,或者是一个将模型名称映射到其 AOTInductor 生成文件的字典。如果只指定了一组文件,它将自动命名为“model”。extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,将作为 pt2 的一部分存储。
opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到该 opset 版本的映射
pickle_protocol (int) – 可以指定以覆盖默认协议
- 返回类型
- torch.export.pt2_archive._package.load_pt2(f, *, expected_opset_version=None, run_single_threaded=False, num_runners=1, device_index=-1, load_weights_from_disk=False)[source]#
加载使用 `package_pt2` 保存的所有工件。
- 参数
f (str | os.PathLike[str] | IO[bytes]) – 文件类对象(必须实现 write 和 flush)或包含文件名的字符串。
expected_opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到预期 opset 版本的映射
num_runners (int) – 加载 AOTInductor 工件的运行器数量
run_single_threaded (bool) – 模型是否应在没有线程同步逻辑的情况下运行。这有助于避免与 CUDAGraphs 冲突。
device_index (int) – 将 PT2 包加载到的设备索引。默认情况下,使用 device_index=-1,当使用 CUDA 时,它对应于 cuda 设备。例如,传递 device_index=1 会将包加载到 cuda:1。
- 返回
一个包含 PT2 中所有对象的
PT2ArchiveContents对象。- 返回类型
PT2ArchiveContents
- torch.export.draft_export(mod, args, kwargs=None, *, dynamic_shapes=None, preserve_module_call_signature=(), strict=False, prefer_deferred_runtime_asserts_over_guards=False)[source]#
一个 `torch.export.export` 的版本,旨在始终如一地生成 ExportedProgram,即使存在潜在的健全性问题,并生成一份报告列出发现的问题。
- 返回类型
- class torch.export.unflatten.FlatArgsAdapter[source]#
使用 `input_spec` 调整输入参数,以匹配 `target_spec`。
- class torch.export.unflatten.InterpreterModule(graph, ty=None)[source]#
一个使用 torch.fx.Interpreter 执行的模块,而不是 GraphModule 使用的常规代码生成。这提供了更好的堆栈跟踪信息,并使执行调试更容易。
- class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source]#
一个模块,它携带与该模块调用序列相对应的 InterpreterModules 序列。每次调用模块时,它会分派到下一个 InterpreterModule,并在最后一个之后回绕。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]#
取消扁平化 ExportedProgram,生成一个具有与原始 eager 模块相同模块层次结构的模块。如果您尝试将
torch.export与期望模块层次结构而不是torch.export通常生成的扁平图的其他系统一起使用,这会很有用。注意
未扁平化模块的 args/kwargs 不一定与 eager 模块匹配,因此进行模块交换(例如,
self.submod = new_mod)不一定有效。如果您需要交换模块,则需要设置torch.export.export()的preserve_module_call_signature参数。- 参数
module (ExportedProgram) – 要取消扁平化的 ExportedProgram。
flat_args_adapter (Optional[FlatArgsAdapter]) – 如果输入 TreeSpec 与导出的模块不匹配,则调整扁平参数。
- 返回
一个
UnflattenedModule实例,它具有与导出现象之前的原始 eager 模块相同的模块层次结构。- 返回类型
UnflattenedModule
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[source]#
将数据类注册为
torch.export.export()的有效输入/输出类型。- 参数
示例
import torch from dataclasses import dataclass @dataclass class InputDataClass: feature: torch.Tensor bias: int @dataclass class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) class Mod(torch.nn.Module): def forward(self, x: InputDataClass) -> OutputDataClass: res = x.feature + x.bias return OutputDataClass(res=res) ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),)) print(ep)
- class torch.export.decomp_utils.CustomDecompTable[source]#
这是一个自定义字典,专门用于处理 export 中的 decomp_table。我们需要这个是因为在新世界中,您只能 *删除* 一个 op 来从 decomp table 中保留它。这对于自定义 op 来说是个问题,因为我们不知道自定义 op 何时才能实际加载到调度器中。因此,我们需要记录自定义 op 操作,直到我们真正需要实现它(这发生在运行分解传递时)。
- 我们保持的不变量是:
所有 aten 分解都在初始化时加载
当用户从表中读取时,我们会实现 *所有* op,以便调度器更有可能拾取自定义 op。
如果是写操作,我们不一定实现
我们将在调用 run_decompositions() 之前,在 export 的最后一次加载。
- torch.export.passes.move_to_device_pass(ep, location)[source]#
将导出的程序移动到指定设备。
- 参数
ep (ExportedProgram) – 要移动的导出的程序。
location (Union[torch.device, str, Dict[str, str]]) – 要将导出的程序移动到的设备。如果为字符串,则将其解释为设备名称。如果是字典,则将其解释为从现有设备到目标设备的映射。
- 返回
移动后的导出的程序。
- 返回类型
- class torch.export.pt2_archive.PT2ArchiveReader(archive_path_or_buffer)#
用于读取 PT2 存档的上下文管理器。
- class torch.export.pt2_archive.PT2ArchiveWriter(archive_path_or_buffer)#
用于写入 PT2 存档的上下文管理器。
- class torch.export.exported_program.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[源]#
- class torch.export.exported_program.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[list[str]] = None)[源]#
- torch.export.exported_program.default_decompositions()[源]#
这是默认的分解表,其中包含所有 ATEN 算子到核心 aten opset 的分解。请将此 API 与
run_decompositions()一起使用。- 返回类型
- class torch.export.custom_obj.ScriptObjectMeta(constant_name, class_fqn)[源]#
存储在代表 ScriptObjects 的节点上的元数据。
- class torch.export.graph_signature.ConstantArgument(name: str, value: Union[int, float, bool, str, NoneType])[源]#
- class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[源]#
- class torch.export.graph_signature.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)[源]#
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[源]#
ExportGraphSignature模拟 Export Graph 的输入/输出签名,这是一个具有更强不变性保证的 fx.Graph。Export Graph 是函数式的,不会通过
getattr节点访问图内的“状态”,例如参数或缓冲区。相反,export()保证参数、缓冲区和常量张量被提升到图外部作为输入。类似地,对缓冲区的任何修改也不会包含在图中,而是将修改后的缓冲区值建模为 Export Graph 的附加输出。所有输入和输出的顺序是
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer("my_buffer1", torch.tensor(3.0)) self.register_buffer("my_buffer2", torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = ( x1 + self.my_parameter ) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output mod = CustomModule() ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
产生的图是非函数式的
graph(): %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] %x1 : [num_users=1] = placeholder[target=x1] %x2 : [num_users=1] = placeholder[target=x2] %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) return (add_1,)
非函数式图产生的 ExportGraphSignature 将是
# inputs p_my_parameter: PARAMETER target='my_parameter' b_my_buffer1: BUFFER target='my_buffer1' persistent=True b_my_buffer2: BUFFER target='my_buffer2' persistent=True x1: USER_INPUT x2: USER_INPUT # outputs add_1: USER_OUTPUT
要获得函数式图,您可以使用
run_decompositions()。mod = CustomModule() ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0))) ep = ep.run_decompositions()
产生的图是函数式的
graph(): %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] %x1 : [num_users=1] = placeholder[target=x1] %x2 : [num_users=1] = placeholder[target=x2] %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) return (add_2, add_1)
函数式图产生的 ExportGraphSignature 将是
# inputs p_my_parameter: PARAMETER target='my_parameter' b_my_buffer1: BUFFER target='my_buffer1' persistent=True b_my_buffer2: BUFFER target='my_buffer2' persistent=True x1: USER_INPUT x2: USER_INPUT # outputs add_2: BUFFER_MUTATION target='my_buffer2' add_1: USER_OUTPUT
- property backward_signature: Optional[ExportBackwardSignature]#
- property buffers: Collection[str]#
- input_specs: list[torch.export.graph_signature.InputSpec]#
- property input_tokens: Collection[str]#
- property lifted_custom_objs: Collection[str]#
- property lifted_tensor_constants: Collection[str]#
- property non_persistent_buffers: Collection[str]#
- output_specs: list[torch.export.graph_signature.OutputSpec]#
- property output_tokens: Collection[str]#
- property parameters: Collection[str]#
- class torch.export.graph_signature.InputKind(value)[源]#
一个枚举。
- BUFFER = 3#
- CONSTANT_TENSOR = 4#
- CUSTOM_OBJ = 5#
- PARAMETER = 2#
- TOKEN = 6#
- USER_INPUT = 1#
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[源]#
- class torch.export.graph_signature.OutputKind(value)[源]#
一个枚举。
- BUFFER_MUTATION = 3#
- GRADIENT_TO_PARAMETER = 5#
- GRADIENT_TO_USER_INPUT = 6#
- LOSS_OUTPUT = 2#
- PARAMETER_MUTATION = 4#
- TOKEN = 8#
- USER_INPUT_MUTATION = 7#
- USER_OUTPUT = 1#
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[源]#
- arg: Union[TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument, ConstantArgument, CustomObjArgument, TokenArgument]#
- kind: OutputKind#