评价此页

torch.jit.annotate#

torch.jit.annotate(the_type, the_value)[source]#

用于向 TorchScript 编译器提供 the_value 的类型提示。

此方法是一个传递函数,它返回 the_value,用于向 TorchScript 编译器提示 the_value 的类型。在 TorchScript 外部运行时,它是一个 no-op。

尽管 TorchScript 可以为大多数 Python 表达式推断正确的类型,但有些情况下的类型推断可能不正确,包括:

  • 空容器,例如 []{},TorchScript 会将它们假定为 Tensor 的容器。

  • 可选类型,例如 Optional[T],但分配了一个有效的 T 类型值,TorchScript 会将其假定为 T 类型,而不是 Optional[T]

请注意,annotate() 不会帮助处理 torch.nn.Module 子类的 __init__ 方法,因为它是在 eager 模式下执行的。要注释 torch.nn.Module 属性的类型,请改用 Attribute()

示例

import torch
from typing import Dict

@torch.jit.script
def fn():
    # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
    # instead of default dictionary type of (str -> Tensor).
    d = torch.jit.annotate(Dict[str, int], {})

    # Without `torch.jit.annotate` above, following statement would fail because of
    # type mismatch.
    d["name"] = 20
参数
  • the_type – 应该作为类型提示传递给 TorchScript 编译器的 Python 类型,用于 the_value

  • the_value – 用于提示类型的(值或表达式)。

返回

the_value 被返回作为返回值。