评价此页

torch.onnx.ops#

创建于:2025 年 6 月 10 日 | 最后更新于:2025 年 6 月 20 日

将 ONNX 算子作为原生 torch.fx 算子。

此模块提供了一组函数,用于在 FX 图中创建可导出到 ONNX 的 ONNX 算子。

符号算子#

可用于在 FX 图中以符号方式创建任何 ONNX 算子的算子。这些算子不执行实际计算。建议您在 if torch.onnx.is_in_onnx_export 块内使用它们。

torch.onnx.ops.symbolic(domain_op, /, inputs, attrs=None, *, dtype, shape, version=None, metadata_props=None)[source]#

创建符号 FX 算子以表示任意 ONNX 算子。

此函数用于创建具有单个输出的符号算子。要创建具有多个输出的算子,请使用 symbolic_multi_out()

您可以利用 if torch.onnx.is_in_onnx_export() 来仅在 torch.onnx.export() 期间有条件地启用符号逻辑。

示例

class CustomOp(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normal torch operators can interleave with the symbolic ops during ONNX export
        x = x + 1

        # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
        # The output tensor will have the specified dtype and shape
        val = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (x,),
            dict(attr_key="attr_value"),
            dtype=x.dtype,
            shape=x.shape,
            version=1,
        )

        # The result of the symbolic op can be used in normal torch operations during ONNX export
        return torch.nn.functional.relu(val)


# You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).
参数
  • domain_op (str) – 域和算子名称,用“::”分隔。例如,“custom_domain::CustomOp”。

  • inputs (Sequence[torch.Tensor | None]) – 算子的输入张量。

  • attrs (dict[str, int | float | str | bool | Sequence[int] | Sequence[float] | Sequence[str] | Sequence[bool]] | None) – 算子的属性。键是属性名,值是属性值。有效的属性类型为 int、float、str、bool,以及 int、float、str 和 bool 的列表。不支持张量属性。

  • dtype (torch.dtype | int) – 输出张量的数据类型。可以是 torch.dtype 或表示 ONNX 数据类型的整数。

  • shape (Sequence[int | torch.SymInt]) – 输出张量的形状。可以是由整数或 SymInt 值组成的列表。

  • version (int | None) – 用于该算子的 opset 版本。

  • metadata_props (dict[str, str] | None) – ONNX 节点的元数据属性。这是一个 str-str 对的字典。

返回

算子的输出张量。

返回类型

torch.Tensor

torch.onnx.ops.symbolic_multi_out(domain_op, /, inputs, attrs=None, *, dtypes, shapes, version=None, metadata_props=None)[source]#

创建符号 FX 算子以表示具有多个输出的任意 ONNX 算子。

您可以利用 if torch.onnx.is_in_onnx_export() 来仅在 torch.onnx.export() 期间有条件地启用符号逻辑。

示例

class CustomOp(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normal torch operators can interleave with the symbolic ops during ONNX export
        x = x + 1

        # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
        # The output tensors will have the specified dtypes and shapes
        (out1, out2) = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (x,),
            dict(attr_key="attr_value"),
            dtypes=(x.dtype, torch.float32),
            shapes=(x.shape, [1, 2, 3]),
            version=1,
        )

        # The result of the symbolic op can be used in normal torch operations during ONNX export
        return torch.nn.functional.relu(out1 + out2)


# You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).
参数
  • domain_op (str) – 域和算子名称,用“::”分隔。例如,“custom_domain::CustomOp”。

  • inputs (Sequence[torch.Tensor | None]) – 算子的输入张量。

  • attrs (dict[str, int | float | str | bool | Sequence[int] | Sequence[float] | Sequence[str] | Sequence[bool]] | None) – 算子的属性。键是属性名,值是属性值。有效的属性类型为 int、float、str、bool,以及 int、float、str 和 bool 的列表。不支持张量属性。

  • dtypes (Sequence[torch.dtype | int]) – 输出张量的数据类型。可以是 torch.dtype 或表示 ONNX 数据类型的整数列表。此列表的长度必须等于输出的数量。

  • shapes (Sequence[Sequence[int | torch.SymInt]]) – 输出张量的形状。可以是由整数或 SymInt 值组成的列表的列表。此列表的长度必须等于输出的数量。

  • version (int | None) – 用于该算子的 opset 版本。

  • metadata_props (dict[str, str] | None) – ONNX 节点的元数据属性。这是一个 str-str 对的字典。

返回

算子的输出张量列表。

返回类型

Sequence[torch.Tensor]

ONNX 算子#

以下算子已作为原生 PyTorch 算子实现,并可以导出为 ONNX 算子。它们可以在 nn.Module 中原生使用。

例如,您可以定义一个模块

class Model(torch.nn.Module):
    def forward(
        self, input_data, cos_cache_data, sin_cache_data, position_ids_data
    ):
        return torch.onnx.ops.rotary_embedding(
            input_data,
            cos_cache_data,
            sin_cache_data,
            position_ids_data,
        )

并使用以下命令将其导出到 ONNX

input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
dynamic_shapes = {
    "input_data": {0: torch.export.Dim.DYNAMIC},
    "cos_cache_data": None,
    "sin_cache_data": None,
    "position_ids_data": {0: torch.export.Dim.DYNAMIC},
}
onnx_program = torch.onnx.export(
    model,
    (input_data, cos_cache_data, sin_cache_data, position_ids_data),
    dynamic_shapes=dynamic_shapes,
    dynamo=True,
    opset_version=23,
)

打印 ONNX 程序将显示图中使用的 ONNX 算子

<...>

graph(
    name=main_graph,
    inputs=(
        %"input_data"<FLOAT,[s0,3,4,8]>,
        %"cos_cache_data"<FLOAT,[50,4]>,
        %"sin_cache_data"<FLOAT,[50,4]>,
        %"position_ids_data"<INT64,[s0,3]>
    ),
    outputs=(
        %"rotary_embedding"<FLOAT,[s0,3,4,8]>
    ),
) {
    0 |  # rotary_embedding
        %"rotary_embedding"<FLOAT,[s0,3,4,8]> ⬅️ ::RotaryEmbedding(%"input_data", %"cos_cache_data", %"sin_cache_data", %"position_ids_data")
    return %"rotary_embedding"<FLOAT,[s0,3,4,8]>
}

以及相应的 ExportedProgram

ExportedProgram

class GraphModule(torch.nn.Module):
    def forward(self, input_data: "f32[s0, 3, 4, 8]", cos_cache_data: "f32[50, 4]", sin_cache_data: "f32[50, 4]", position_ids_data: "i64[s0, 3]"):
        rotary_embedding: "f32[s0, 3, 4, 8]" = torch.ops.onnx.RotaryEmbedding.opset23(input_data, cos_cache_data, sin_cache_data, position_ids_data);  input_data = cos_cache_data = sin_cache_data = position_ids_data = None
        return (rotary_embedding,)
torch.onnx.ops.rotary_embedding(X, cos_cache, sin_cache, position_ids=None, *, interleaved=False, num_heads=0, rotary_embedding_dim=0)[source]#

ONNX 中的 RotaryEmbedding 算子。

https://onnx.org.cn/onnx/operators/onnx__RotaryEmbedding.html

RotaryEmbedding 是基于论文 https://arxiv.org/pdf/2104.09864 实现的旋转位置嵌入 (RoPE)。RoPE 的主要优点是它允许模型理解 token 的绝对位置和 token 之间的相对距离。这是通过一种旋转机制实现的,其中旋转的程度是根据 token 的绝对位置 (position_ids) 计算的。嵌入向量被分成两半,或者每隔一个 token 进行交错,然后将旋转矩阵应用于嵌入向量的每个半部分。旋转矩阵由序列中 token 的位置参数化。将嵌入向量的旋转半部分连接起来,形成每个 token 的最终位置嵌入。在自注意力机制中使用旋转的位置嵌入。旋转确保模型能够捕获绝对和相对位置信息。

旋转机制由正弦和余弦函数定义,用于表示旋转角度。对于序列中的每个 token,其位置嵌入是通过旋转其嵌入向量来计算的。这可以通过将嵌入向量分成两半,或者交错每隔一个 token 来实现,然后将旋转矩阵应用于嵌入向量的每个半部分。旋转矩阵由序列中 token 的位置参数化。将旋转后的嵌入向量的半部分连接起来,形成每个 token 的最终位置嵌入。在自注意力机制中使用旋转后的位置嵌入。这种旋转确保了模型捕获了绝对和相对位置信息。

参数
  • X (torch.Tensor) – 表示 token 嵌入的输入张量。形状为 (batch_size, num_heads, sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, sequence_length, hidden_size) 的 3D 张量。对于 4D 输入张量,head_size 必须是偶数。对于 3D 输入张量,必须提供 num_heads 属性,并且 hidden_size 必须是 num_heads 的偶数倍,其中 hidden_size = num_heads * head_size

  • cos_cache (torch.Tensor) – 旋转的余弦值。对于完整旋转,形状为 (max_position_id_plus_1, head_size / 2) 的二维张量;对于部分旋转且提供了 position_ids 时,形状为 (max_position_id_plus_1, rotary_embedding_dim / 2) 的二维张量。对于完整旋转,形状为 (batch_size, sequence_length, head_size / 2) 的三维张量;对于部分旋转且未提供 position_ids 时,形状为 (batch_size, sequence_length, rotary_embedding_dim / 2) 的三维张量。max_position_id_plus_1 是模型的参数。

  • sin_cache (torch.Tensor) – 旋转的正弦值。对于完整旋转,形状为 (max_position_id_plus_1, head_size / 2) 的二维张量;对于部分旋转且提供了 position_ids 时,形状为 (max_position_id_plus_1, rotary_embedding_dim / 2) 的二维张量。对于完整旋转,形状为 (batch_size, sequence_length, head_size / 2) 的三维张量;对于部分旋转且未提供 position_ids 时,形状为 (batch_size, sequence_length, rotary_embedding_dim / 2) 的三维张量。max_position_id_plus_1 是模型的参数。

  • position_ids (torch.Tensor | None) – 词元的 positional 索引。形状为 (batch_size, sequence_length) 的二维张量。

  • interleaved (bool) – 使用交错模式进行旋转。默认值为 0 (False)。

  • num_heads (int) – 注意力头的数量。当输入是三维张量时必须提供。

  • rotary_embedding_dim (int) – 应用部分旋转嵌入所使用的旋转嵌入维度。

返回

与输入形状相同的张量。

返回类型

torch.Tensor

torch.onnx.ops.attention(Q, K, V, attn_mask=None, past_key=None, past_value=None, *, is_causal=False, kv_num_heads=0, q_num_heads=0, qk_matmul_output_mode=0, scale=None, softcap=0.0, softmax_precision=None)[来源]#

ONNX 中的 Attention 操作。

https://onnx.org.cn/onnx/operators/onnx__Attention.html

在可选的注意力掩码(如果提供)下,计算查询、键和值张量上的缩放点积注意力。

此运算符根据 K、Q 和 V 的序列长度,涵盖了注意力操作的自注意力和交叉注意力变体。

对于自注意力,kv_sequence_length 等于 q_sequence_length

对于交叉注意力,查询和键可能具有不同的长度。

此运算符还涵盖了基于头数量的以下 3 种变体:

  1. 多头注意力 (MHA):在论文 https://arxiv.org/pdf/1706.03762 中所述,q_num_heads = kv_num_heads

  2. 分组查询注意力 (GQA):在论文 https://arxiv.org/pdf/2305.13245 中所述,q_num_heads > kv_num_headsq_num_heads % kv_num_heads == 0

  3. 多查询注意力 (MQA):在论文 https://arxiv.org/pdf/1911.02150 中所述,q_num_heads > kv_num_headskv_num_heads=1

注意力偏置的添加是基于 attn_mask 输入和 is_causal 属性计算的,这两个参数只能提供其中一个。

  1. 如果将 is_causal 设置为 1,则当掩码为方阵时,注意力掩码是下三角矩阵。由于对齐,注意力掩码具有左上角因果偏置的形式。

  2. attn_mask:一个布尔掩码,其中 True 值表示该元素应参与注意力;或者一个与查询、键、值相同类型的浮点掩码,该掩码被添加到注意力分数中。

过去和现在的键/值状态都是可选的。它们应一起使用,不允许只使用其中一个。在根据提供的序列长度和头数对 K 和 V 输入进行适当重塑后,将对 Q、K 和 V 输入应用以下模式:

The following pattern is applied by this operator:
        Q          K          V
        |          |          |
Q*sqrt(scale) K*sqrt(scale) |
        |          |          |
        |       Transpose     |
        |          |          |
        ---MatMul---          |
            |               |
at_mask---Add              |
            |               |
    softcap (if provided)     |
            |               |
        Softmax            |
            |               |
            -----MatMul------
                    |
                    Y
参数
  • Q (torch.Tensor) – 查询张量。形状为 (batch_size, q_num_heads, q_sequence_length, head_size) 的四维张量,或形状为 (batch_size, q_sequence_length, q_hidden_size) 的三维张量。对于三维输入张量的情况,q_hidden_size = q_num_heads * head_size

  • K (torch.Tensor) – 键张量。形状为 (batch_size, kv_num_heads, kv_sequence_length, head_size) 的四维张量,或形状为 (batch_size, kv_sequence_length, k_hidden_size) 的三维张量。对于三维输入张量的情况,k_hidden_size = kv_num_heads * head_size

  • V (torch.Tensor) – 值张量。形状为 (batch_size, kv_num_heads, kv_sequence_length, v_head_size) 的四维张量,或形状为 (batch_size, kv_sequence_length, v_hidden_size) 的三维张量。对于三维输入张量的情况,v_hidden_size = kv_num_heads * v_head_size

  • attn_mask (torch.Tensor | None) – 注意力掩码。形状必须可广播到形状为 (batch_size, q_num_heads, q_sequence_length, total_sequence_length) 的四维张量,其中 total_sequence_length = past_sequence_length + kv_sequence_length。支持两种类型的掩码。一种是布尔掩码,其中 True 值表示该元素应参与注意力。还支持与查询、键、值相同类型的浮点掩码,该掩码被添加到注意力分数中。

  • past_key (torch.Tensor | None) – 键的过去状态缓存,形状为 (batch_size, kv_num_heads, past_sequence_length, head_size)

  • past_value (torch.Tensor | None) – 值的过去状态缓存,形状为 (batch_size, kv_num_heads, past_sequence_length, v_head_size)

  • is_causal (bool) – 如果设置为 True,则当掩码为方阵时,注意力掩码是下三角矩阵。由于对齐,注意力掩码具有左上角因果偏置的形式。

  • kv_num_heads (int) – 键和值的头数。必须与 Q、K 和 V 的三维输入一起使用。

  • q_num_heads (int) – 查询的头数。必须与 Q、K 和 V 的三维输入一起使用。

  • qk_matmul_output_mode (int) – 如果设置为 0,则 qk_matmul_output 是 qk 矩阵乘法的输出。如果设置为 1,则 qk_matmul_output 包括将注意力掩码添加到 qk 矩阵乘法输出。如果设置为 2,则 qk_matmul_output 是 softcap 操作后的输出。如果设置为 3,则 qk_matmul_output 是 softmax 操作后的输出。默认值为 0。

  • scale (float | None) – 应用于 Q*K^T 的缩放因子。默认值为 1/sqrt(head_size)。为防止数值溢出,请在矩阵乘法前将 Q、K 乘以 sqrt(scale)。

  • softcap (float) – 注意力权重的 softcap 值。默认值为 0。

  • softmax_precision (int | None) – 在 softmax 计算中使用的浮点精度。如果未提供 softmax 精度,则使用 softmax 输入(Q 和 K)的相同精度。

返回

  • 输出张量。形状为 (batch_size, q_num_heads, q_sequence_length, v_head_size) 的四维张量,或形状为 (batch_size, q_sequence_length, hidden_size) 的三维张量。对于三维输入张量的情况,hidden_size = q_num_heads * v_head_size

  • 更新后的键缓存,形状为 (batch_size, kv_num_heads, total_sequence_length, head_size),其中 total_sequence_length = past_sequence_length + kv_sequence_length

  • 更新后的值缓存,形状为 (batch_size, kv_num_heads, total_sequence_length, v_head_size),其中 total_sequence_length = past_sequence_length + kv_sequence_length

  • QK 矩阵乘法的输出。形状为 (batch_size, q_num_heads, q_sequence_length, total_sequence_length) 的四维张量,其中 total_sequence_length = past_sequence_length + kv_sequence_length

返回类型

一个包含的元组

ONNX 到 ATen 分解表#

您可以使用 torch.onnx.ops.aten_decompositions() 来获取一个分解表,将上面定义的 ONNX 运算符分解为 ATen 运算符。

class Model(torch.nn.Module):
    def forward(
        self, input_data, cos_cache_data, sin_cache_data, position_ids_data
    ):
        return torch.onnx.ops.rotary_embedding(
            input_data,
            cos_cache_data,
            sin_cache_data,
            position_ids_data,
        )

model = Model()

ep = torch.export.export(
    model,
    (input_data, cos_cache_data, sin_cache_data, position_ids_data),
)
# The program can be decomposed into aten ops
ep_decomposed = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
torch.onnx.ops.aten_decompositions()[来源]#

返回 ONNX 到 ATen 的分解表。

返回类型

dict[torch._ops.OpOverload, Callable]