torch.jit.annotate#
- torch.jit.annotate(the_type, the_value)[source]#
用于向 TorchScript 编译器提供 the_value 的类型提示。
此方法是一个传递函数,它返回 the_value,用于向 TorchScript 编译器提示 the_value 的类型。在 TorchScript 外部运行时,它是一个 no-op。
尽管 TorchScript 可以为大多数 Python 表达式推断正确的类型,但有些情况下的类型推断可能不正确,包括:
空容器,例如 [] 和 {},TorchScript 会将它们假定为 Tensor 的容器。
可选类型,例如 Optional[T],但分配了一个有效的 T 类型值,TorchScript 会将其假定为 T 类型,而不是 Optional[T]。
请注意,annotate() 不会帮助处理 torch.nn.Module 子类的 __init__ 方法,因为它是在 eager 模式下执行的。要注释 torch.nn.Module 属性的类型,请改用
Attribute()
。示例
import torch from typing import Dict @torch.jit.script def fn(): # Telling TorchScript that this empty dictionary is a (str -> int) dictionary # instead of default dictionary type of (str -> Tensor). d = torch.jit.annotate(Dict[str, int], {}) # Without `torch.jit.annotate` above, following statement would fail because of # type mismatch. d["name"] = 20
- 参数
the_type – 应该作为类型提示传递给 TorchScript 编译器的 Python 类型,用于 the_value。
the_value – 用于提示类型的(值或表达式)。
- 返回
the_value 被返回作为返回值。