torch.jit.annotate#
- torch.jit.annotate(the_type, the_value)[source]#
用于在 TorchScript 编译器中指定 the_value 的类型。
此方法是一个传递函数,返回 the_value,用于向 TorchScript 编译器提示 the_value 的类型。在 TorchScript 外部运行时,它是一个空操作。
尽管 TorchScript 可以推断大多数 Python 表达式的正确类型,但在某些情况下,类型推断可能不正确,包括:
空容器,如 [] 和 {},TorchScript 会假定它们是 Tensor 的容器。
可选类型,如 Optional[T],但被赋值了类型为 T 的有效值。TorchScript 会假定其类型为 T 而非 Optional[T]。
请注意,annotate() 对 torch.nn.Module 子类的 __init__ 方法没有帮助,因为它是在 eager 模式下执行的。要注释 torch.nn.Module 属性的类型,请使用
Attribute()。示例
import torch from typing import Dict @torch.jit.script def fn(): # Telling TorchScript that this empty dictionary is a (str -> int) dictionary # instead of default dictionary type of (str -> Tensor). d = torch.jit.annotate(Dict[str, int], {}) # Without `torch.jit.annotate` above, following statement would fail because of # type mismatch. d["name"] = 20
- 参数
the_type – 应传递给 TorchScript 编译器的 Python 类型,作为 the_value 的类型提示。
the_value – 要提示其类型的数值或表达式。
- 返回
the_value 将作为返回值返回。