评价此页

torch.export API 参考#

创建于: 2025年7月17日 | 最后更新于: 2025年7月17日

torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=False, preserve_module_call_signature=(), prefer_deferred_runtime_asserts_over_guards=False)[source]#

export() 接收任意 nn.Module 和示例输入,以提前(AOT)方式生成一个仅表示函数张量计算的跟踪图,之后可以使用不同的输入执行或序列化该图。跟踪图 (1) 在函数式 ATen 运算符集中生成规范化的运算符(以及用户指定的任何自定义运算符),(2) 消除了所有 Python 控制流和数据结构(某些例外情况除外),并且 (3) 记录了显示这种规范化和控制流消除对于未来输入是健全的形状约束。

健全性保证

在跟踪期间,export() 会记录用户程序和底层 PyTorch 运算符内核所做的与形状相关的假设。只有当这些假设成立时,生成的 ExportedProgram 才被认为是有效的。

跟踪会做出关于输入张量形状(而非值)的假设。为了使 export() 成功,这些假设必须在图捕获时进行验证。具体来说:

  • 对输入张量静态形状的假设无需额外工作即可自动验证。

  • 对输入张量动态形状的假设需要通过使用 Dim() API 来构造动态维度,并通过 dynamic_shapes 参数将其与示例输入关联来显式指定。

如果任何假设无法验证,将引发致命错误。发生这种情况时,错误消息将包含对验证假设所需的规范的建议修复。例如,export() 可能会为动态维度 dim0_x 的定义提出以下修复,该维度出现在输入 x 的形状中,该维度以前定义为 Dim("dim0_x")

dim = Dim("dim0_x", max=5)

此示例意味着生成的代码要求输入 x 的维度 0 小于或等于 5 才能有效。您可以检查动态维度定义的建议修复,然后将其逐字复制到您的代码中,而无需更改传递给 export() 调用的 dynamic_shapes 参数。

参数
  • mod (Module) – 我们将跟踪此模块的 forward 方法。

  • args (tuple[Any, ...]) – 示例位置输入。

  • kwargs (Optional[Mapping[str, Any]]) – 可选的示例关键字输入。

  • dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]]) –

    一个可选参数,其类型应为:1)一个字典,从 f 的参数名称映射到其动态形状规范,2)一个元组,指定按原始顺序排列的每个输入的动态形状规范。如果您正在为关键字参数指定动态性,则需要按照函数原始签名中定义的顺序传递它们。

    张量参数的动态形状可以指定为:1)一个从动态维度索引到 Dim() 类型的字典,其中不需要在此字典中包含静态维度索引,但当它们存在时,应将其映射到 None;或 2)一个 Dim() 类型或 None 的元组/列表,其中 Dim() 类型对应动态维度,静态维度由 None 表示。由字典或张量元组/列表组成的参数通过使用包含的规范的映射或序列来递归指定。

  • strict (bool) – 当禁用(默认)时,export 函数将通过 Python 运行时跟踪程序,这本身不会验证图中的一些隐式假设。它仍然会验证大多数关键假设,例如形状安全性。当启用(通过设置 strict=True)时,export 函数将通过 TorchDynamo 跟踪程序,这将确保生成图的健全性。TorchDynamo 对 Python 特性的覆盖有限,因此您可能会遇到更多错误。请注意,切换此参数不会影响生成的 IR 规范,模型将以相同的方式序列化,无论此处传递什么值。

  • preserve_module_call_signature (tuple[str, ...]) – 一个子模块路径列表,将保留其原始调用约定作为元数据。调用 torch.export.unflatten 时将使用元数据来保留模块的原始调用约定。

返回

一个包含跟踪的可调用对象的 ExportedProgram

返回类型

ExportedProgram

可接受的输入/输出类型

可接受的输入(用于 argskwargs)和输出类型包括:

  • 基本类型,即 torch.Tensorintfloatboolstr

  • 数据类,但它们必须首先通过调用 register_dataclass() 进行注册。

  • (嵌套的) 数据结构,由 dictlisttuplenamedtupleOrderedDict 组成,其中包含上述所有类型。

class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source]#

来自 export() 的程序包。它包含一个 torch.fx.Graph,该图表示张量计算,一个包含所有提升的参数和缓冲区张量值的 state_dict,以及各种元数据。

您可以使用与 export() 跟踪的原始可调用对象相同的调用约定来调用 ExportedProgram。

要对图执行转换,请使用 `.module` 属性访问 torch.fx.GraphModule。然后,您可以使用 FX 转换 来重写图。之后,您只需再次使用 export() 即可构建一个正确的 ExportedProgram。

buffers()[source]#

返回原始模块缓冲区的迭代器。

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

返回类型

Iterator[Tensor]

property call_spec#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property constants#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property dialect: str#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property example_inputs#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property graph#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property graph_module#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property graph_signature#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

module(check_guards=True)[source]#

返回一个自包含的 GraphModule,其中所有参数/缓冲区都已内联。

  • check_guards=True (默认) 时,将生成一个 _guards_fn 子模块,并在图中的占位符之后插入一个对 _guards_fn 子模块的调用。此模块检查输入的 guard。

  • check_guards=False 时,一部分检查将由图模块的 forward pre-hook 执行。不会生成 _guards_fn 子模块。

返回类型

GraphModule

property module_call_graph#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

named_buffers()[source]#

返回一个迭代器,其中包含原始模块缓冲区,同时生成缓冲区的名称以及缓冲区本身。

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

返回类型

Iterator[tuple[str, torch.Tensor]]

named_parameters()[source]#

返回一个迭代器,其中包含原始模块参数,同时生成参数的名称以及参数本身。

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

返回类型

Iterator[tuple[str, torch.nn.parameter.Parameter]]

parameters()[source]#

返回原始模块参数的迭代器。

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

返回类型

Iterator[Parameter]

property range_constraints#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source]#

对导出的程序运行一组分解,并返回一个新的导出的程序。默认情况下,我们将运行核心 ATen 分解,以在 Core ATen Operator Set 中获得运算符。

目前,我们不分解联合图。

参数

decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – 一个可选参数,指定 Aten ops 的分解行为 (1) 如果为 None,我们分解为核心 aten 分解 (2) 如果为空,我们不分解任何运算符

返回类型

ExportedProgram

一些例子

如果您不想分解任何内容

ep = torch.export.export(model, ...)
ep = ep.run_decompositions(decomp_table={})

如果您想获取核心 aten 运算符集,但排除某些运算符,您可以这样做:

ep = torch.export.export(model, ...)
decomp_table = torch.export.default_decompositions()
decomp_table[your_op] = your_custom_decomp
ep = ep.run_decompositions(decomp_table=decomp_table)
property state_dict#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property tensor_constants#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

validate()[source]#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property verifier: Any#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

property verifiers#

警告

此 API 仍处于实验阶段,并且 *不* 向后兼容。

class torch.export.dynamic_shapes.AdditionalInputs[source]#

根据附加输入推断 dynamic_shapes。

这对于部署工程师特别有用,他们一方面可能拥有充足的测试或分析数据,可以提供对模型代表性输入的良好认识,但另一方面,他们可能对模型了解不够,无法猜测哪些输入形状应该是动态的。

与原始输入不同的输入形状被视为动态;反之,与原始输入相同的形状被视为静态。此外,我们验证附加输入对于导出的程序是有效的。这保证了用它们代替原始输入进行跟踪会生成相同的图。

示例

args0, kwargs0 = ...  # example inputs for export

# other representative inputs that the exported program will run on
dynamic_shapes = torch.export.AdditionalInputs()
dynamic_shapes.add(args1, kwargs1)
...
dynamic_shapes.add(argsN, kwargsN)

torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)
add(args, kwargs=None)[source]#

附加输入 args()kwargs()

dynamic_shapes(m, args, kwargs=None)[source]#

通过合并原始输入 args()kwargs() 以及每个附加输入 args 和 kwargs 的形状,推断出 dynamic_shapes() 的 Pytree 结构。

verify(ep)[source]#

验证导出的程序对于每个附加输入是否有效。

class torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source]#

Dim 类允许用户在导出的程序中指定动态性。通过用 Dim 标记一个维度,编译器会将该维度与包含动态范围的符号整数关联起来。

该 API 可以以两种方式使用:Dim 提示(即自动动态形状:Dim.AUTODim.DYNAMICDim.STATIC)或命名 Dim(即 Dim("name", min=1, max=2))。

Dim 提示提供了导出能力的最低门槛,用户只需指定维度是动态的、静态的,还是由编译器决定(Dim.AUTO)。导出过程将自动推断剩余的关于最小/最大范围以及维度之间关系的约束。

示例

class Foo(nn.Module):
    def forward(self, x, y):
        assert x.shape[0] == 4
        assert y.shape[0] >= 16
        return x @ y


x = torch.randn(4, 8)
y = torch.randn(8, 16)
dynamic_shapes = {
    "x": {0: Dim.AUTO, 1: Dim.AUTO},
    "y": {0: Dim.AUTO, 1: Dim.AUTO},
}
ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)

在这里,如果我们用 Dim.DYNAMIC 替换所有 Dim.AUTO 的用法,导出将引发异常,因为模型已将 x.shape[0] 约束为静态。

维度之间更复杂的关系也可能被编译器编码为运行时断言节点,例如 (x.shape[0] + y.shape[1]) % 4 == 0,如果运行时输入不满足这些约束,将引发该断言。

您还可以为 Dim 提示指定最小-最大边界,例如 Dim.AUTO(min=16, max=32)Dim.DYNAMIC(max=64),编译器将在这些范围内的剩余约束进行推断。如果有效范围完全超出用户指定的范围,将引发异常。

命名 Dim 提供了一种更严格的方式来指定动态性,如果编译器推断出的约束与用户规范不匹配,则会引发异常。例如,导出之前的模型,用户将需要以下 dynamic_shapes 参数。

s0 = Dim("s0")
s1 = Dim("s1", min=16)
dynamic_shapes = {
    "x": {0: 4, 1: s0},
    "y": {0: s0, 1: s1},
}
ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)

命名 Dim 还允许指定维度之间的关系,最多为单变量线性关系。例如,以下表示一个维度是另一个维度的倍数加 4。

s0 = Dim("s0")
s1 = 3 * s0 + 4
class torch.export.dynamic_shapes.ShapesCollection[source]#

dynamic_shapes 的构建器。用于为输入中出现的张量分配动态形状规范。

这特别有用,当 args() 是嵌套输入结构时,索引输入张量比在 dynamic_shapes() 规范中复制 args() 的结构更容易。

示例

args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

dim = torch.export.Dim(...)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}
# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)

要为整数指定动态性,我们需要先用 `_IntWrapper` 包装整数,这样我们就可以为每个整数拥有一个“唯一标识符”。

示例

args = {"x": tensor_x, "others": [int_x, int_y]}
# Wrap all ints with _IntWrapper
mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)

dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC

# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)
dynamic_shapes(m, args, kwargs=None)[source]#

根据 args()kwargs() 生成 dynamic_shapes() Pytree 结构。

torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source]#

当使用 dynamic_shapes() 进行导出时,如果规范与从跟踪模型推断出的约束不匹配,导出可能会因 ConstraintViolation 错误而失败。错误消息可能会提供建议的修复 - 可以对 dynamic_shapes() 进行的更改,以成功导出。

示例 ConstraintViolation 错误消息

Suggested fixes:

    dim = Dim('dim', min=3, max=6)  # this just refines the dim's range
    dim = 4  # this specializes to a constant
    dy = dx + 1  # dy was specified as an independent dim, but is actually tied to dx with this relation

这是一个辅助函数,它接收 ConstraintViolation 错误消息和原始 dynamic_shapes() 规范,并返回一个包含建议修复的新 dynamic_shapes() 规范。

使用示例

try:
    ep = export(mod, args, dynamic_shapes=dynamic_shapes)
except torch._dynamo.exc.UserError as exc:
    new_shapes = refine_dynamic_shapes_from_suggested_fixes(
        exc.msg, dynamic_shapes
    )
    ep = export(mod, args, dynamic_shapes=new_shapes)
返回类型

Union[dict[str, Any], tuple[Any], list[Any]]

torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)[source]#

警告

正在积极开发中,保存的文件可能无法在新版本的 PyTorch 中使用。

ExportedProgram 保存到文件类对象。然后可以使用 Python API torch.export.load 加载它。

参数
  • ep (ExportedProgram) – 要保存的导出的程序。

  • f (str | os.PathLike[str] | IO[bytes]) – 实现 write 和 flush 的文件对象,或包含文件名的字符串。

  • extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,将作为 f 的一部分存储。

  • opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到该 opset 版本的映射

  • pickle_protocol (int) – 可以指定以覆盖默认协议

示例

import torch
import io


class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10


ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, "exported_program.pt2")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {"foo.txt": b"bar".decode("utf-8")}
torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]#

警告

正在积极开发中,保存的文件可能无法在新版本的 PyTorch 中使用。

加载之前使用 torch.export.save 保存的 ExportedProgram

参数
  • f (str | os.PathLike[str] | IO[bytes]) – 文件类对象(必须实现 write 和 flush)或包含文件名的字符串。

  • extra_files (Optional[Dict[str, Any]]) – 此映射中提供的附加文件名将被加载,其内容将存储在提供的映射中。

  • expected_opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到预期 opset 版本的映射

返回

一个 ExportedProgram 对象

返回类型

ExportedProgram

示例

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load("exported_program.pt2")

# Load ExportedProgram from io.BytesIO object
with open("exported_program.pt2", "rb") as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {"foo.txt": ""}  # values will be replaced with data
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])
print(ep(torch.randn(5)))
torch.export.pt2_archive._package.package_pt2(f, *, exported_programs=None, aoti_files=None, extra_files=None, opset_version=None, pickle_protocol=2)[source]#

将工件保存为 PT2Archive 格式。该工件随后可以使用 `load_pt2` 加载。

参数
  • f (str | os.PathLike[str] | IO[bytes]) – 文件类对象(必须实现 write 和 flush)或包含文件名的字符串。

  • exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]) – 要保存的导出的程序,或者是一个将模型名称映射到导出的程序的字典。导出的程序将保存在 models/*.json 下。如果只指定了一个 ExportedProgram,它将自动命名为“model”。

  • aoti_files (Union[list[str], dict[str, list[str]]]) – 由 AOTInductor 通过 torch._inductor.aot_compile(..., {"aot_inductor.package": True}) 生成的文件列表,或者是一个将模型名称映射到其 AOTInductor 生成文件的字典。如果只指定了一组文件,它将自动命名为“model”。

  • extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,将作为 pt2 的一部分存储。

  • opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到该 opset 版本的映射

  • pickle_protocol (int) – 可以指定以覆盖默认协议

返回类型

Union[str, PathLike[str], IO[bytes]]

torch.export.pt2_archive._package.load_pt2(f, *, expected_opset_version=None, run_single_threaded=False, num_runners=1, device_index=-1, load_weights_from_disk=False)[source]#

加载使用 `package_pt2` 保存的所有工件。

参数
  • f (str | os.PathLike[str] | IO[bytes]) – 文件类对象(必须实现 write 和 flush)或包含文件名的字符串。

  • expected_opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到预期 opset 版本的映射

  • num_runners (int) – 加载 AOTInductor 工件的运行器数量

  • run_single_threaded (bool) – 模型是否应在没有线程同步逻辑的情况下运行。这有助于避免与 CUDAGraphs 冲突。

  • device_index (int) – 将 PT2 包加载到的设备索引。默认情况下,使用 device_index=-1,当使用 CUDA 时,它对应于 cuda 设备。例如,传递 device_index=1 会将包加载到 cuda:1

返回

一个包含 PT2 中所有对象的 PT2ArchiveContents 对象。

返回类型

PT2ArchiveContents

torch.export.draft_export(mod, args, kwargs=None, *, dynamic_shapes=None, preserve_module_call_signature=(), strict=False, prefer_deferred_runtime_asserts_over_guards=False)[source]#

一个 `torch.export.export` 的版本,旨在始终如一地生成 ExportedProgram,即使存在潜在的健全性问题,并生成一份报告列出发现的问题。

返回类型

ExportedProgram

class torch.export.unflatten.FlatArgsAdapter[source]#

使用 `input_spec` 调整输入参数,以匹配 `target_spec`。

abstract adapt(target_spec, input_spec, input_args, metadata=None, obj=None)[source]#

注意:此适配器可能会修改给定的 `input_args_with_path`。

返回类型

list[Any]

get_flat_arg_paths()[source]#

返回用于访问扁平参数的路径列表。

返回类型

list[str]

class torch.export.unflatten.InterpreterModule(graph, ty=None)[source]#

一个使用 torch.fx.Interpreter 执行的模块,而不是 GraphModule 使用的常规代码生成。这提供了更好的堆栈跟踪信息,并使执行调试更容易。

class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source]#

一个模块,它携带与该模块调用序列相对应的 InterpreterModules 序列。每次调用模块时,它会分派到下一个 InterpreterModule,并在最后一个之后回绕。

torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]#

取消扁平化 ExportedProgram,生成一个具有与原始 eager 模块相同模块层次结构的模块。如果您尝试将 torch.export 与期望模块层次结构而不是 torch.export 通常生成的扁平图的其他系统一起使用,这会很有用。

注意

未扁平化模块的 args/kwargs 不一定与 eager 模块匹配,因此进行模块交换(例如,self.submod = new_mod)不一定有效。如果您需要交换模块,则需要设置 torch.export.export()preserve_module_call_signature 参数。

参数
  • module (ExportedProgram) – 要取消扁平化的 ExportedProgram。

  • flat_args_adapter (Optional[FlatArgsAdapter]) – 如果输入 TreeSpec 与导出的模块不匹配,则调整扁平参数。

返回

一个 UnflattenedModule 实例,它具有与导出现象之前的原始 eager 模块相同的模块层次结构。

返回类型

UnflattenedModule

torch.export.register_dataclass(cls, *, serialized_type_name=None)[source]#

将数据类注册为 torch.export.export() 的有效输入/输出类型。

参数
  • cls (type[Any]) – 要注册的数据类类型

  • serialized_type_name (Optional[str]) – 数据类的序列化名称。这是

  • this (required if you want to serialize the pytree TreeSpec containing) – (如果要序列化包含数据类的 Pytree TreeSpec,则需要此项)

  • dataclass.

示例

import torch
from dataclasses import dataclass


@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int


@dataclass
class OutputDataClass:
    res: torch.Tensor


torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)


class Mod(torch.nn.Module):
    def forward(self, x: InputDataClass) -> OutputDataClass:
        res = x.feature + x.bias
        return OutputDataClass(res=res)


ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
print(ep)
class torch.export.decomp_utils.CustomDecompTable[source]#

这是一个自定义字典,专门用于处理 export 中的 decomp_table。我们需要这个是因为在新世界中,您只能 *删除* 一个 op 来从 decomp table 中保留它。这对于自定义 op 来说是个问题,因为我们不知道自定义 op 何时才能实际加载到调度器中。因此,我们需要记录自定义 op 操作,直到我们真正需要实现它(这发生在运行分解传递时)。

我们保持的不变量是:
  1. 所有 aten 分解都在初始化时加载

  2. 当用户从表中读取时,我们会实现 *所有* op,以便调度器更有可能拾取自定义 op。

  3. 如果是写操作,我们不一定实现

  4. 我们将在调用 run_decompositions() 之前,在 export 的最后一次加载。

copy()[source]#
返回类型

CustomDecompTable

items()[source]#
keys()[source]#
materialize()[source]#
返回类型

dict[torch._ops.OperatorBase, Callable]

pop(*args)[source]#
update(other_dict)[source]#
torch.export.passes.move_to_device_pass(ep, location)[source]#

将导出的程序移动到指定设备。

参数
  • ep (ExportedProgram) – 要移动的导出的程序。

  • location (Union[torch.device, str, Dict[str, str]]) – 要将导出的程序移动到的设备。如果为字符串,则将其解释为设备名称。如果是字典,则将其解释为从现有设备到目标设备的映射。

返回

移动后的导出的程序。

返回类型

ExportedProgram

class torch.export.pt2_archive.PT2ArchiveReader(archive_path_or_buffer)#

用于读取 PT2 存档的上下文管理器。

archive_version()[source]#

获取存档版本。

返回类型

int

get_file_names()[source]#

获取存档中的文件名。

返回类型

list[str]

read_bytes(name)[source]#

从存档中读取字节对象。name:存档内的源文件名。

返回类型

字节

read_string(name)[source]#

从存档中读取字符串对象。name:存档内的源文件名。

返回类型

str

class torch.export.pt2_archive.PT2ArchiveWriter(archive_path_or_buffer)#

用于写入 PT2 存档的上下文管理器。

close()[source]#

关闭存档。

count_prefix(prefix)[source]#

计算以给定前缀开头记录的数量。

返回类型

int

has_record(name)[source]#

检查存档中是否存在记录。

返回类型

布尔值

write_bytes(name, data)[source]#

将字节对象写入存档。name:存档内的目标文件名。data:要写入的字节对象。

write_file(name, file_path)[source]#

将文件复制到存档中。name:存档内的目标文件名。file_path:磁盘上的源文件。

write_folder(archive_dir, folder_dir)[source]#

将文件夹复制到归档中。archive_dir:归档内的目标文件夹。folder_dir:磁盘上的源文件夹。

write_string(name, data)[源]#

将字符串对象写入归档。name:归档内的目标文件名。data:要写入的字符串对象。

torch.export.pt2_archive.is_pt2_package(serialized_model)[源]#

检查序列化模型是否为 PT2 归档包。

返回类型

布尔值

class torch.export.exported_program.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[源]#
class torch.export.exported_program.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[list[str]] = None)[源]#
torch.export.exported_program.default_decompositions()[源]#

这是默认的分解表,其中包含所有 ATEN 算子到核心 aten opset 的分解。请将此 API 与 run_decompositions() 一起使用。

返回类型

CustomDecompTable

class torch.export.custom_obj.ScriptObjectMeta(constant_name, class_fqn)[源]#

存储在代表 ScriptObjects 的节点上的元数据。

class torch.export.graph_signature.ConstantArgument(name: str, value: Union[int, float, bool, str, NoneType])[源]#
name: str#
value: Optional[Union[int, float, bool, str]]#
class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[源]#
class_fqn: str#
fake_val: Optional[FakeScriptObject] = None#
name: str#
class torch.export.graph_signature.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)[源]#
gradients_to_parameters: dict[str, str]#
gradients_to_user_inputs: dict[str, str]#
loss_output: str#
class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[源]#

ExportGraphSignature 模拟 Export Graph 的输入/输出签名,这是一个具有更强不变性保证的 fx.Graph。

Export Graph 是函数式的,不会通过 getattr 节点访问图内的“状态”,例如参数或缓冲区。相反,export() 保证参数、缓冲区和常量张量被提升到图外部作为输入。类似地,对缓冲区的任何修改也不会包含在图中,而是将修改后的缓冲区值建模为 Export Graph 的附加输出。

所有输入和输出的顺序是

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果导出以下模块

class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer("my_buffer1", torch.tensor(3.0))
        self.register_buffer("my_buffer2", torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (
            x1 + self.my_parameter
        ) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0)  # In-place addition

        return output


mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))

产生的图是非函数式的

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_1,)

非函数式图产生的 ExportGraphSignature 将是

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_1: USER_OUTPUT

要获得函数式图,您可以使用 run_decompositions()

mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
ep = ep.run_decompositions()

产生的图是函数式的

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_2, add_1)

函数式图产生的 ExportGraphSignature 将是

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_2: BUFFER_MUTATION target='my_buffer2'
add_1: USER_OUTPUT
property assertion_dep_token: Optional[Mapping[int, str]]#
property backward_signature: Optional[ExportBackwardSignature]#
property buffers: Collection[str]#
property buffers_to_mutate: Mapping[str, str]#
get_replace_hook(replace_inputs=False)[源]#
input_specs: list[torch.export.graph_signature.InputSpec]#
property input_tokens: Collection[str]#
property inputs_to_buffers: Mapping[str, str]#
property inputs_to_lifted_custom_objs: Mapping[str, str]#
property inputs_to_lifted_tensor_constants: Mapping[str, str]#
property inputs_to_parameters: Mapping[str, str]#
property lifted_custom_objs: Collection[str]#
property lifted_tensor_constants: Collection[str]#
property non_persistent_buffers: Collection[str]#
output_specs: list[torch.export.graph_signature.OutputSpec]#
property output_tokens: Collection[str]#
property parameters: Collection[str]#
property parameters_to_mutate: Mapping[str, str]#
replace_all_uses(old, new)[源]#

在签名中用新名称替换所有旧名称的使用。

property user_inputs: Collection[Union[int, float, bool, None, str]]#
property user_inputs_to_mutate: Mapping[str, str]#
property user_outputs: Collection[Union[int, float, bool, None, str]]#
class torch.export.graph_signature.InputKind(value)[源]#

一个枚举。

BUFFER = 3#
CONSTANT_TENSOR = 4#
CUSTOM_OBJ = 5#
PARAMETER = 2#
TOKEN = 6#
USER_INPUT = 1#
class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[源]#
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument, ConstantArgument, CustomObjArgument, TokenArgument]#
kind: InputKind#
persistent: Optional[bool] = None#
target: Optional[str]#
class torch.export.graph_signature.OutputKind(value)[源]#

一个枚举。

BUFFER_MUTATION = 3#
GRADIENT_TO_PARAMETER = 5#
GRADIENT_TO_USER_INPUT = 6#
LOSS_OUTPUT = 2#
PARAMETER_MUTATION = 4#
TOKEN = 8#
USER_INPUT_MUTATION = 7#
USER_OUTPUT = 1#
class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[源]#
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument, ConstantArgument, CustomObjArgument, TokenArgument]#
kind: OutputKind#
target: Optional[str]#
class torch.export.graph_signature.SymBoolArgument(name: str)[源]#
name: str#
class torch.export.graph_signature.SymFloatArgument(name: str)[源]#
name: str#
class torch.export.graph_signature.SymIntArgument(name: str)[源]#
name: str#
class torch.export.graph_signature.TensorArgument(name: str)[源]#
name: str#
class torch.export.graph_signature.TokenArgument(name: str)[源]#
name: str#