评价此页

控制流 - Cond#

创建于: 2023年10月03日 | 最后更新于: 2025年06月13日

torch.cond 是一个结构化控制流算子。它可以用于指定 if-else 类的控制流,并且在逻辑上可以看作是如下实现的。

def cond(
    pred: Union[bool, torch.Tensor],
    true_fn: Callable,
    false_fn: Callable,
    operands: Tuple[torch.Tensor]
):
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)

它独特的力量在于其表达 **数据依赖型控制流** 的能力:它会降低为一个条件算子(torch.ops.higher_order.cond),该算子保留了谓词、真函数和假函数。这极大地提高了编写和部署那些根据张量操作的输入或中间输出的 **值** 或 **形状** 来改变模型架构的模型所带来的灵活性。

警告

torch.cond 是 PyTorch 中的一个原型功能。它对输入和输出类型支持有限,并且目前不支持训练。请期待 PyTorch 未来版本中更稳定的实现。有关功能分类的更多信息,请参阅:https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype

示例#

下面是一个使用 cond 根据输入形状进行分支的示例

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            def true_fn(x: torch.Tensor):
                return x.cos()

            def false_fn(x: torch.Tensor):
                return x.sin()

            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

我们可以立即运行模型,并期望结果根据输入形状而变化

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

我们可以导出模型以进行进一步的转换和部署

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

这将为我们提供一个导出的程序,如下所示

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

请注意,torch.cond 被降低为 torch.ops.higher_order.cond,其谓词成为输入形状上的符号表达式,分支函数成为顶级图模块的两个子图属性。

这是另一个示例,展示了如何表达数据依赖型控制流

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on data dependent predicate.
        """
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

导出后我们得到的导出程序

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

torch.ops.higher_order.cond 的不变式#

对于 torch.ops.higher_order.cond 有几个有用的不变式

  • 对于谓词

    • 谓词的动态性被保留(例如,上面示例中所示的 gt

    • 如果用户程序中的谓词是常量(例如,一个 Python 的 bool 常量),则算子的 pred 将是一个常量。

  • 对于分支

    • 输入和输出签名将是一个展平的元组。

    • 它们是 torch.fx.GraphModule

    • 原始函数中的闭包成为显式输入。没有闭包。

    • 不允许对输入或全局变量进行修改。

  • 对于操作数

    • 它也将是一个扁平的元组。

  • 用户程序中 torch.cond 的嵌套将成为嵌套的图模块。

API 参考#

torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands=())[source]#

有条件地应用 true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的一个原型功能。它对输入和输出类型支持有限,并且目前不支持训练。请期待 PyTorch 未来版本中更稳定的实现。有关功能分类的更多信息,请参阅:https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype

cond 是结构化控制流算子。也就是说,它类似于 Python 的 if 语句,但对 true_fnfalse_fnoperands 有限制,这些限制使其能够被 torch.compile 和 torch.export 捕获。

假设满足 cond 参数的约束条件,cond 等价于以下内容

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
参数
  • pred (Union[bool, torch.Tensor]) – 一个布尔表达式或一个只有一个元素的张量,指示应用哪个分支函数。

  • true_fn (Callable) – 一个可调用函数(a -> b),它在被跟踪的范围内。

  • false_fn (Callable) – 一个可调用函数(a -> b),它在被跟踪的范围内。真分支和假分支必须具有一致的输入和输出,这意味着输入必须相同,输出必须是相同的类型和形状。也允许 int 输出。我们将通过将其转换为 symint 来使输出动态化。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 一个输入真/假函数的元组。如果 true_fn/false_fn 不需要输入,则可以为空。默认为 ()。

返回类型

任何

示例

def true_fn(x: torch.Tensor):
    return x.cos()


def false_fn(x: torch.Tensor):
    return x.sin()


return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
限制
  • 条件语句(又名 pred)必须满足以下任一约束:

    • 它是一个只有一个元素的 torch.Tensor,且 dtype 为 torch.bool

    • 它是一个布尔表达式,例如 x.shape[0] > 10x.dim() > 1 and x.shape[1] > 10

  • 分支函数(又名 true_fn/false_fn)必须满足以下所有约束:

    • 函数签名必须与操作数匹配。

    • 函数必须返回具有相同元数据(例如,形状、dtype 等)的张量。

    • 函数不能对输入或全局变量进行原地修改。(注意:像 add_ 这样的原地张量操作用于中间结果是被允许在分支中使用的)