评价此页

torch.cond#

torch.cond(pred, true_fn, false_fn, operands=())[源码]#

有条件地应用 true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的一个原型(prototype)功能。它对输入和输出类型有有限的支持,目前不支持训练。请期待 PyTorch 未来版本中更稳定的实现。有关功能分类的更多信息,请参阅:https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype

cond 是一个结构化控制流算子。也就是说,它类似于 Python 的 if 语句,但在 true_fnfalse_fnoperands 上有限制,这使得它可以使用 torch.compile 和 torch.export 进行捕获。

假设 cond 参数的约束条件已满足,cond 等价于以下代码:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
参数
  • pred (Union[bool, torch.Tensor]) – 一个布尔表达式或一个只有一个元素的张量,指示应用哪个分支函数。

  • true_fn (Callable) – 一个在正在追踪的范围内可调用的函数(a -> b)。

  • false_fn (Callable) – 一个在正在追踪的范围内可调用的函数(a -> b)。真分支和假分支必须具有一致的输入和输出,这意味着输入必须相同,输出必须是相同的类型和形状。也允许 int 输出。我们将通过将其转换为 symint 来使输出动态化。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 输入到 true/false 函数的元组。如果 true_fn/false_fn 不需要输入,则可以为空。默认为 ()。

返回类型

任何

示例

def true_fn(x: torch.Tensor):
    return x.cos()


def false_fn(x: torch.Tensor):
    return x.sin()


return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
约束条件
  • 条件语句(也称为 pred)必须满足以下任一约束条件:

    • 它是一个只有一个元素的 torch.Tensor,且 dtype 为 torch.bool。

    • 它是一个布尔表达式,例如 x.shape[0] > 10x.dim() > 1 and x.shape[1] > 10

  • 分支函数(也称为 true_fn/false_fn)必须满足以下所有约束条件:

    • 函数签名必须与 operands 匹配。

    • 函数必须返回一个具有相同元数据(例如,形状、dtype 等)的张量。

    • 函数不能对输入或全局变量进行就地(in-place)修改。(注意:对于中间结果的就地张量操作,如 add_,在分支中是允许的)。