TorchScript 语言参考#
创建于:2021年3月10日 | 最后更新:2025年6月13日
本参考手册描述了 TorchScript 语言的语法和核心语义。TorchScript 是 Python 语言的一个静态类型子集。本文档解释了 TorchScript 中支持的 Python 特性,以及该语言与常规 Python 的区别。本参考手册未提及的 Python 特性均不属于 TorchScript。TorchScript 专注于 PyTorch 中用于表示神经网络模型的 Python 特性。
术语#
本文档使用以下术语
Pattern |
注意事项 |
---|---|
|
表示给定符号定义为。 |
|
表示属于语法一部分的真实关键字和分隔符。 |
|
表示 A 或 B。 |
|
表示分组。 |
|
表示可选。 |
|
表示一个正则表达式,其中术语 A 重复至少一次。 |
|
表示一个正则表达式,其中术语 A 重复零次或多次。 |
类型系统#
TorchScript 是 Python 的一个静态类型子集。TorchScript 和完整的 Python 语言之间最大的区别在于,TorchScript 仅支持用于表达神经网络模型的一小部分类型。
TorchScript 类型#
TorchScript 类型系统由下文定义的 `TSType
` 和 `TSModuleType
` 组成。
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
`TSType
` 代表了 TorchScript 大部分可组合且可在 TorchScript 类型注解中使用的类型。`TSType
` 指代以下任何一项
Meta 类型,例如 `
Any
`基本类型,例如 `
int
`、`float
` 和 `str
`结构化类型,例如 `
Optional[int]
` 或 `List[MyClass]
`名义类型(Python 类),例如 `
MyClass
`(用户定义)、`torch.tensor
`(内置)
`TSModuleType
` 代表 `torch.nn.Module
` 及其子类。由于其类型模式部分从对象实例推断,部分从类定义推断,因此 `TSModuleType
` 的处理方式与 `TSType
` 不同。因此,`TSModuleType
` 的实例可能不遵循相同的静态类型模式。出于类型安全考虑,`TSModuleType
` 不能用作 TorchScript 类型注解,也不能与 `TSType
` 组合。
Meta 类型#
Meta 类型非常抽象,更像是类型约束而不是具体类型。目前 TorchScript 定义了一个 meta 类型 `Any
`,它代表任何 TorchScript 类型。
`Any
` 类型#
`Any
` 类型代表任何 TorchScript 类型。`Any
` 不指定任何类型约束,因此对 `Any
` 没有类型检查。因此,它可以绑定到任何 Python 或 TorchScript 数据类型(例如,`int
`、TorchScript `tuple
`,或未被脚本化的任意 Python 类)。
TSMetaType ::= "Any"
其中
`
Any
` 是 `typing` 模块中的 Python 类名。因此,要使用 `Any
` 类型,您必须从 `typing
` 中导入它(例如,`from typing import Any
`)。由于 `
Any
` 可以代表任何 TorchScript 类型,因此可以在 `Any
` 类型的数据上操作的运算符集是有限的。
支持 `Any
` 类型的运算符#
赋值给 `
Any
` 类型的数据。绑定到 `
Any
` 类型的参数或返回值。`
x is
`、`x is not
`,其中 `x
` 是 `Any
` 类型。`
isinstance(x, Type)
`,其中 `x
` 是 `Any
` 类型。`
Any
` 类型的数据是可打印的。`
List[Any]
` 类型的数据如果是相同类型 `T
` 的值列表,并且 `T
` 支持比较运算符,那么它可以是可排序的。
与 Python 相比
`Any
` 是 TorchScript 类型系统中约束最少的类型。从这个意义上说,它与 Python 中的 `Object
` 类非常相似。但是,`Any
` 只支持 `Object
` 支持的运算符和方法的子集。
设计说明#
当我们对 PyTorch 模块进行脚本化时,可能会遇到不参与脚本执行的数据。尽管如此,它仍必须由类型模式来描述。为未使用的数据(在脚本的上下文中)描述静态类型不仅很麻烦,而且可能导致不必要的脚本化失败。引入 `Any
` 是为了描述那些对于编译来说不需要精确静态类型的 <$> 数据的类型。
示例 1
此示例说明了如何使用 `Any
` 来允许元组参数的第二个元素为任何类型。这是可能的,因为 `x[1]
` 不涉及任何需要了解其精确类型的计算。
import torch
from typing import Tuple
from typing import Any
@torch.jit.export
def inc_first_element(x: Tuple[int, Any]):
return (x[0]+1, x[1])
m = torch.jit.script(inc_first_element)
print(m((1,2.0)))
print(m((1,(100,200))))
上面的示例产生以下输出
(2, 2.0)
(2, (100, 200))
元组的第二个元素是 `Any
` 类型,因此可以绑定到多种类型。例如,`(1, 2.0)
` 将浮点类型绑定到 `Any
`,如 `Tuple[int, Any]
`,而 `(1, (100, 200))
` 在第二次调用中将元组绑定到 `Any
`。
示例 2
此示例说明了我们如何使用 `isinstance
` 来动态检查被注解为 `Any
` 类型的数据的类型
import torch
from typing import Any
def f(a:Any):
print(a)
return (isinstance(a, torch.Tensor))
ones = torch.ones([2])
m = torch.jit.script(f)
print(m(ones))
上面的示例产生以下输出
1
1
[ CPUFloatType{2} ]
True
基本类型#
基本 TorchScript 类型是表示单一类型值的类型,并带有单一预定义的类型名称。
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
结构化类型#
结构化类型是那些在结构上定义但没有用户定义名称(与名义类型不同)的类型,例如 `Future[int]
`。结构化类型可以与任何 `TSType
` 组合。
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict |
TSOptional | TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
其中
`
Tuple
`、`List
`、`Optional
`、`Union
`、`Future
`、`Dict
` 代表 `typing` 模块中定义的 Python 类型类名。要使用这些类型名称,您必须从 `typing
` 中导入它们(例如,`from typing import Tuple
`)。`
namedtuple
` 代表 Python 类 `collections.namedtuple
` 或 `typing.NamedTuple
`。`
Future
` 和 `RRef
` 分别代表 Python 类 `torch.futures
` 和 `torch.distributed.rpc
`。`
Await
` 代表 Python 类 `torch._awaits._Await
`
与 Python 相比
除了可以与 TorchScript 类型组合外,这些 TorchScript 结构化类型通常支持其 Python 对等项的常见运算符和方法子集。
示例 1
此示例使用 `typing.NamedTuple
` 语法定义一个元组
import torch
from typing import NamedTuple
from typing import Tuple
class MyTuple(NamedTuple):
first: int
second: int
def inc(x: MyTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
t = MyTuple(first=1, second=2)
scripted_inc = torch.jit.script(inc)
print("TorchScript:", scripted_inc(t))
上面的示例产生以下输出
TorchScript: (2, 3)
示例 2
此示例使用 `collections.namedtuple
` 语法定义一个元组
import torch
from typing import NamedTuple
from typing import Tuple
from collections import namedtuple
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second'])
def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
m = torch.jit.script(inc)
print(inc(_UnannotatedNamedTuple(1,2)))
上面的示例产生以下输出
(2, 3)
示例 3
此示例说明了注解结构化类型时的一个常见错误,即未从 `typing
` 模块导入组合类型类
import torch
# ERROR: Tuple not recognized because not imported from typing
@torch.jit.export
def inc(x: Tuple[int, int]):
return (x[0]+1, x[1]+1)
m = torch.jit.script(inc)
print(m((1,2)))
运行上述代码会产生以下脚本错误
File "test-tuple.py", line 5, in <module>
def inc(x: Tuple[int, int]):
NameError: name 'Tuple' is not defined
解决方案是,在代码开头添加一行 from typing import Tuple
。
标称类型#
TorchScript 的标称类型是 Python 类。这些类型被称为标称类型,因为它们是用自定义名称声明的,并且通过类名进行比较。标称类进一步分为以下几类:
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
其中,TSCustomClass
和 TSEnum
必须可编译为 TorchScript 中间表示 (IR)。这由类型检查器强制执行。
内置类#
内置标称类型是语义已内置到 TorchScript 系统中的 Python 类(例如,张量类型)。TorchScript 定义了这些内置标称类型的语义,并且通常只支持其 Python 类定义的某些方法或属性的子集。
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
"torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" |
"torch.nn.parameter.Parameter" | and subclasses of torch.Tensor
关于 torch.nn.ModuleList 和 torch.nn.ModuleDict 的特别说明#
尽管 torch.nn.ModuleList
和 torch.nn.ModuleDict
在 Python 中被定义为列表和字典,但在 TorchScript 中它们的行为更像元组。
在 TorchScript 中,
torch.nn.ModuleList
或torch.nn.ModuleDict
的实例是不可变的。遍历
torch.nn.ModuleList
或torch.nn.ModuleDict
的代码会被完全展开,以便torch.nn.ModuleList
的元素或torch.nn.ModuleDict
的键可以是torch.nn.Module
的不同子类。
示例
以下示例重点介绍了几个内置 TorchScript 类(torch.*
)的用法。
import torch
@torch.jit.script
class A:
def __init__(self):
self.x = torch.rand(3)
def f(self, y: torch.device):
return self.x.to(device=y)
def g():
a = A()
return a.f(torch.device("cpu"))
script_g = torch.jit.script(g)
print(script_g.graph)
自定义类#
与内置类不同,自定义类的语义由用户定义,并且整个类定义必须可编译为 TorchScript IR,并遵循 TorchScript 的类型检查规则。
TSClassDef ::= [ "@torch.jit.script" ]
"class" ClassName [ "(object)" ] ":"
MethodDefinition |
[ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ]
MethodDefinition
其中
类必须是新式类。Python 3 只支持新式类。在 Python 2.x 中,通过继承 object 来指定新式类。
实例数据属性是静态类型的,实例属性必须在
__init__()
方法内通过赋值来声明。不支持方法重载(即,不能有同名方法)。
MethodDefinition
必须可编译为 TorchScript IR 并遵循 TorchScript 的类型检查规则(即,所有方法都必须是有效的 TorchScript 函数,类属性定义必须是有效的 TorchScript 语句)。torch.jit.ignore
和torch.jit.unused
可用于忽略未完全 TorchScript 化或应被编译器忽略的方法或函数。
与 Python 相比
与 Python 对应物相比,TorchScript 的自定义类相当有限。TorchScript 自定义类:
不支持类属性。
不支持子类化,除非子类化接口类型或对象。
不支持方法重载。
必须在
__init__()
中初始化其所有实例属性;这是因为 TorchScript 通过在__init__()
中推断属性类型来构建类的静态架构。必须只包含满足 TorchScript 类型检查规则且可编译为 TorchScript IR 的方法。
示例 1
Python 类可以用于 TorchScript,前提是它们被注解为 @torch.jit.script
,这与声明 TorchScript 函数的方式类似。
@torch.jit.script
class MyClass:
def __init__(self, x: int):
self.x = x
def inc(self, val: int):
self.x += val
示例 2
TorchScript 自定义类类型必须通过在 __init__()
中的赋值来“声明”其所有实例属性。如果一个实例属性未在 __init__()
中定义,却在类的其他方法中访问,则该类无法编译为 TorchScript 类,如下例所示:
import torch
@torch.jit.script
class foo:
def __init__(self):
self.y = 1
# ERROR: self.x is not defined in __init__
def assign_x(self):
self.x = torch.rand(2, 3)
该类将编译失败并报错如下:
RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
def assign_x(self):
self.x = torch.rand(2, 3)
~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
示例 3
在此示例中,TorchScript 自定义类定义了一个类变量 name,这是不允许的。
import torch
@torch.jit.script
class MyClass(object):
name = "MyClass"
def __init__(self, x: int):
self.x = x
def fn(a: MyClass):
return a.name
这会导致以下编译时错误:
RuntimeError:
'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?:
File "test-class2.py", line 10
def fn(a: MyClass):
return a.name
~~~~~~ <--- HERE
枚举类型#
与自定义类一样,枚举类型的语义由用户定义,整个类定义必须可编译为 TorchScript IR 并遵循 TorchScript 类型检查规则。
TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":"
( MemberIdentifier "=" Value )+
( MethodDefinition )*
其中
值必须是
int
、float
或str
类型的 TorchScript 字面量,并且必须是相同的 TorchScript 类型。TSEnumType
是 TorchScript 枚举类型的名称。与 Python 枚举类似,TorchScript 允许受限的Enum
子类化,也就是说,只有在枚举类不定义任何成员时才允许子类化。
与 Python 相比
TorchScript 只支持
enum.Enum
。它不支持其他变体,如enum.IntEnum
、enum.Flag
、enum.IntFlag
和enum.auto
。TorchScript 枚举成员的值必须是相同类型,并且只能是
int
、float
或str
类型,而 Python 枚举成员可以是任何类型。包含方法的枚举在 TorchScript 中会被忽略。
示例 1
以下示例将 Color
类定义为 Enum
类型。
import torch
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
print("TorchScript: ", m(Color.RED, Color.GREEN))
示例 2
以下示例展示了受限的枚举子类化情况,其中 BaseColor
未定义任何成员,因此可以被 Color
子类化。
import torch
from enum import Enum
class BaseColor(Enum):
def foo(self):
pass
class Color(BaseColor):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("TorchScript: ", m(Color.RED, Color.GREEN))
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
TorchScript 模块类#
TSModuleType
是一种特殊的类类型,它是从 TorchScript 外部创建的对象实例推断出来的。TSModuleType
的名称来自对象实例的 Python 类。__init__()
方法不被视为 TorchScript 方法,因此不必遵守 TorchScript 的类型检查规则。
模块实例类的类型架构直接从实例对象(在 TorchScript 范围外创建)构建,而不是像自定义类那样从 __init__()
推断。同一实例类类型的两个对象可能遵循不同的类型架构。
从这个意义上说,TSModuleType
并不是真正的静态类型。因此,出于类型安全考虑,TSModuleType
不能用于 TorchScript 类型注解,也不能与 TSType
组合。
模块实例类#
TorchScript 模块类型表示用户定义的 PyTorch 模块实例的类型架构。在脚本化 PyTorch 模块时,模块对象始终在 TorchScript 外部创建(即,作为参数传递给 forward
)。Python 模块类被视为模块实例类,因此 Python 模块类的 __init__()
方法不受 TorchScript 类型检查规则的约束。
TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":"
ClassBodyDefinition
其中
forward()
和用@torch.jit.export
注解的其他方法必须可编译为 TorchScript IR 并遵循 TorchScript 的类型检查规则。
与自定义类不同,只有模块类型的 forward
方法和其他用 @torch.jit.export
注解的方法需要可编译。最值得注意的是,__init__()
不被视为 TorchScript 方法。因此,不能在 TorchScript 范围内部调用模块类型构造函数。相反,TorchScript 模块对象始终在外部构造并传递给 torch.jit.script(ModuleObj)
。
示例 1
此示例说明了模块类型的几个特性。
TestModule
实例是在 TorchScript 范围之外创建的(即,在调用torch.jit.script
之前)。__init__()
不被视为 TorchScript 方法,因此无需注解,并且可以包含任意 Python 代码。此外,无法在 TorchScript 代码中调用实例类的__init__()
方法。由于TestModule
实例在 Python 中实例化,在本例中,TestModule(2.0)
和TestModule(2)
创建了两个具有不同类型的数据属性的实例。TestModule(2.0)
的self.x
类型为float
,而TestModule(2.0)
的self.y
类型为int
。TorchScript 会自动编译被
@torch.jit.export
注解的方法或forward()
方法调用的其他方法(例如mul()
)。TorchScript 程序的入口点是模块类型的
forward()
、被注解为torch.jit.script
的函数,或被注解为torch.jit.export
的方法。
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, inc: int):
return self.x + inc
m = torch.jit.script(TestModule(1))
print(f"First instance: {m(3)}")
m = torch.jit.script(TestModule(torch.ones([5])))
print(f"Second instance: {m(3)}")
上面的示例产生以下输出
First instance: 4
Second instance: tensor([4., 4., 4., 4., 4.])
示例 2
以下示例展示了模块类型的不正确用法。特别是,该示例在 TorchScript 范围内部调用了 TestModule
的构造函数。
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, x: int):
return self.x + x
class MyModel:
def __init__(self, v: int):
self.val = v
@torch.jit.export
def doSomething(self, val: int) -> int:
# error: should not invoke the constructor of module type
myModel = TestModule(self.val)
return myModel(val)
# m = torch.jit.script(MyModel(2)) # Results in below RuntimeError
# RuntimeError: Could not get name of python class object
类型注解#
由于 TorchScript 是静态类型的,程序员需要在 TorchScript 代码的战略点注解类型,以便每个局部变量或实例数据属性都有一个静态类型,并且每个函数和方法都有一个静态类型签名。
何时注解类型#
通常,类型注解仅在静态类型无法自动推断的地方需要(例如,参数或方法的返回类型)。局部变量和数据属性的类型通常可以从其赋值语句自动推断。有时,推断出的类型可能过于严格,例如 x
通过赋值 x = None
被推断为 NoneType
,而 x
实际上用作 Optional
。在这种情况下,可能需要类型注解来覆盖自动推断,例如 x: Optional[int] = None
。请注意,即使局部变量或数据属性的类型可以自动推断,也始终可以对其进行类型注解。注解的类型必须与 TorchScript 的类型检查兼容。
当参数、局部变量或数据属性未进行类型注解且其类型无法自动推断时,TorchScript 假定它们是默认类型 TensorType
、List[TensorType]
或 Dict[str, TensorType]
。
注解函数签名#
由于参数可能无法从函数体(包括函数和方法)自动推断,因此需要对其进行类型注解。否则,它们将假定默认类型 TensorType
。
TorchScript 支持两种风格的方法和函数签名类型注解:
Python3 风格直接在签名上注解类型。因此,它允许单独的参数不被注解(其类型将是
TensorType
的默认类型),或者允许返回类型不被注解(其类型将自动推断)。
Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":"
FuncOrMethodBody
ParamAnnot ::= Identifier [ ":" TSType ] ","
ReturnAnnot ::= "->" TSType
请注意,在使用 Python3 风格时,self
的类型会自动推断,不应被注解。
Mypy 风格将类型作为注释直接写在函数/方法声明的下方。在 Mypy 风格中,由于参数名不会出现在注解中,因此所有参数都必须被注解。
MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ]
ParamAnnot ::= TSType ","
ReturnAnnot ::= "->" TSType
示例 1
在此示例中:
a
未被注解,并假定TensorType
的默认类型。b
被注解为int
类型。返回类型未被注解,并自动推断为
TensorType
(基于返回值的类型)。
import torch
def f(a, b: int):
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
示例 2
以下示例使用 Mypy 风格注解。请注意,即使某些参数或返回值假定为默认类型,也必须对其进行注解。
import torch
def f(a, b):
# type: (torch.Tensor, int) → torch.Tensor
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
注解变量和数据属性#
通常,数据属性(包括类和实例数据属性)和局部变量的类型可以从赋值语句自动推断。但是,有时如果变量或属性与不同类型的值相关联(例如,与 None
或 TensorType
),则可能需要将其明确注解为更宽泛的类型,如 Optional[int]
或 Any
。
局部变量#
局部变量可以根据 Python3 typing 模块注解规则进行注解,即:
LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr
通常,局部变量的类型可以自动推断。但在某些情况下,您可能需要为可能与不同具体类型关联的多类型局部变量进行注解。典型的多类型包括 Optional[T]
和 Any
。
示例
import torch
def f(a, setVal: bool):
value: Optional[torch.Tensor] = None
if setVal:
value = a
return value
ones = torch.ones([6])
m = torch.jit.script(f)
print("TorchScript:", m(ones, True), m(ones, False))
实例数据属性#
对于 ModuleType
类,实例数据属性可以根据 Python3 typing 模块注解规则进行注解。实例数据属性可以(可选地)被注解为 final,使用 Final
。
"class" ClassIdentifier "(torch.nn.Module):"
InstanceAttrIdentifier ":" ["Final("] TSType [")"]
...
其中
InstanceAttrIdentifier
是实例属性的名称。Final
表示属性不能在__init__
之外重新赋值或在子类中重写。
示例
import torch
class MyModule(torch.nn.Module):
offset_: int
def __init__(self, offset):
self.offset_ = offset
...
类型注解 API#
torch.jit.annotate(T, expr)
#
此 API 将类型 T
注解到表达式 expr
。当表达式的默认类型不是程序员期望的类型时,通常使用此 API。例如,空列表(字典)的默认类型是 List[TensorType]
(Dict[TensorType, TensorType]
),但有时它可能用于初始化其他类型的列表。另一个常见用例是注解 tensor.tolist()
的返回类型。但是请注意,它不能用于注解 __init__
中模块属性的类型;应改用 torch.jit.Attribute
。
示例
在此示例中,[]
被声明为整数列表,通过 torch.jit.annotate
(而不是假定 []
为 List[TensorType]
的默认类型)。
import torch
from typing import List
def g(l: List[int], val: int):
l.append(val)
return l
def f(val: int):
l = g(torch.jit.annotate(List[int], []), val)
return l
m = torch.jit.script(f)
print("Eager:", f(3))
print("TorchScript:", m(3))
有关更多信息,请参见 torch.jit.annotate()
。
类型注解附录#
TorchScript 类型系统定义#
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSMetaType ::= "Any"
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional |
TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"|
"torch.dtype" | "torch.nn.ModuleList" |
"torch.nn.ModuleDict" | ...
TSTensor ::= "torch.tensor" and subclasses
不支持的类型构造#
TorchScript 不支持 Python3 typing 模块的所有特性和类型。 typing 模块中未在此文档中明确说明的任何功能均不支持。下表总结了 TorchScript 中不支持或支持受限的 typing
构造。
项 |
描述 |
|
开发中 |
|
不支持 |
|
不支持 |
|
不支持 |
|
不支持 |
|
支持模块属性、类属性和注解,但不支持函数。 |
|
不支持 |
|
开发中 |
类型别名 |
不支持 |
标称类型 |
开发中 |
结构类型 |
不支持 |
NewType |
不支持 |
泛型 |
不支持 |
表达式#
以下部分描述了 TorchScript 中支持的表达式的语法。它以 Python 语言参考的表达式章节为模型。
算术转换#
TorchScript 中会执行一些隐式类型转换:
具有
float
或int
数据类型的Tensor
可以隐式转换为FloatType
或IntType
的实例,前提是它的尺寸为 0,没有将require_grad
设置为True
,并且不需要缩小。StringType
的实例可以隐式转换为DeviceType
。上述两点中的隐式转换规则可以应用于
TupleType
的实例,以生成具有适当包含类型的ListType
实例。
显式转换可以使用接受原始数据类型作为参数的内置函数 float
、int
、bool
和 str
来调用,如果用户定义类型实现了 __bool__
、__str__
等。
原子#
原子是表达式的最基本元素。
atom ::= identifier | literal | enclosure
enclosure ::= parenth_form | list_display | dict_display
标识符#
规定 TorchScript 中合法标识符的规则与其 Python 对应项相同。
字面量#
literal ::= stringliteral | integer | floatnumber
字面量的求值会产生具有特定值(根据需要对浮点数应用近似值)的适当类型的对象。字面量是不可变的,对相同字面量的多次求值可能获得相同的对象或具有相同值的不同对象。字符串字面量、整数和浮点数定义方式与其 Python 对应项相同。
带括号的表达式#
parenth_form ::= '(' [expression_list] ')'
带括号的表达式列表会产生表达式列表产生的任何内容。如果列表至少包含一个逗号,它会产生一个 Tuple
;否则,它会产生表达式列表中的单个表达式。一对空括号会产生一个空 Tuple
对象(Tuple[]
)。
列表和字典显示#
list_comprehension ::= expression comp_for
comp_for ::= 'for' target_list 'in' or_expr
list_display ::= '[' [expression_list | list_comprehension] ']'
dict_display ::= '{' [key_datum_list | dict_comprehension] '}'
key_datum_list ::= key_datum (',' key_datum)*
key_datum ::= expression ':' expression
dict_comprehension ::= key_datum comp_for
列表和字典可以通过显式列出容器内容或通过一组循环指令(即推导式)提供计算它们的说明来构造。推导式在语义上等同于使用 for 循环并追加到正在进行的列表中。推导式会隐式创建自己的作用域,以确保列表项不会泄露到封闭作用域中。如果容器项被显式列出,则表达式列表从左到右求值。如果在具有 key_datum_list
的 dict_display
中重复了键,则生成的字典将使用列表中使用重复键的最后一个数据的值。
基本项#
primary ::= atom | attributeref | subscription | slicing | call
订阅#
subscription ::= primary '[' expression_list ']'
基本项必须求值为支持订阅的对象。
如果基本项是
List
、Tuple
或str
,则表达式列表必须求值为整数或切片。如果基本项是
Dict
,则表达式列表必须求值为与Dict
的键类型相同的对象。如果基本项是
ModuleList
,则表达式列表必须是整数字面量。如果基本项是
ModuleDict
,则表达式必须是字符串字面量。
切片#
切片选择 str
、Tuple
、List
或 Tensor
中的一系列项。切片可用作赋值或 del
语句中的表达式或目标。
slicing ::= primary '[' slice_list ']'
slice_list ::= slice_item (',' slice_item)* [',']
slice_item ::= expression | proper_slice
proper_slice ::= [expression] ':' [expression] [':' [expression] ]
切片列表中具有多个切片项的切片只能与求值为 Tensor
类型的对象的基本项一起使用。
调用#
call ::= primary '(' argument_list ')'
argument_list ::= args [',' kwargs] | kwargs
args ::= [arg (',' arg)*]
kwargs ::= [kwarg (',' kwarg)*]
kwarg ::= arg '=' expression
arg ::= identifier
基本项必须解开或求值为可调用对象。在尝试调用之前,所有参数表达式都会被求值。
幂运算符#
power ::= primary ['**' u_expr]
幂运算符具有与内置 pow 函数(不支持)相同的语义;它计算其左参数的右参数次方。它的结合性比左侧的一元运算符强,但比右侧的一元运算符弱;即 -2 ** -3 == -(2 ** (-3))
。左操作数和右操作数可以是 int
、float
或 Tensor
。在标量-张量/张量-标量幂运算的情况下,标量会被广播,而张量-张量幂运算是逐元素进行的,没有任何广播。
一元和算术按位运算#
u_expr ::= power | '-' power | '~' power
一元 -
运算符产生其参数的否定。一元 ~
运算符产生其参数的按位求反。-
可与 int
、float
和 int
和 float
类型的 Tensor
一起使用。~
只能与 int
和 int
类型的 Tensor
一起使用。
二元算术运算#
m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr
a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr
二元算术运算符可以操作 Tensor
、int
和 float
。对于张量-张量运算,两个参数必须具有相同的形状。对于标量-张量或张量-标量运算,标量通常被广播到张量的尺寸。除法运算只能接受标量作为其右侧参数,并且不支持广播。@
运算符用于矩阵乘法,并且只操作 Tensor
参数。乘法运算符(*
)可以与列表和整数一起使用,以获得结果,该结果是原始列表重复一定次数。
移位运算#
shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr
这些运算符接受两个 int
参数,两个 Tensor
参数,或一个 Tensor
参数和一个 int
或 float
参数。在所有情况下,右移 n
定义为地板除以 pow(2, n)
,左移 n
定义为乘以 pow(2, n)
。当两个参数都是 Tensor
时,它们必须具有相同的形状。当一个参数是标量而另一个参数是 Tensor
时,标量会逻辑上广播以匹配 Tensor
的大小。
二进制按位运算#
and_expr ::= shift_expr | and_expr '&' shift_expr
xor_expr ::= and_expr | xor_expr '^' and_expr
or_expr ::= xor_expr | or_expr '|' xor_expr
运算符 &
计算其参数的按位 AND,^
计算按位 XOR,|
计算按位 OR。两个操作数都必须是 int
或 Tensor
,或者左操作数必须是 Tensor
而右操作数必须是 int
。当两个操作数都是 Tensor
时,它们必须具有相同的形状。当右操作数是 int
,而左操作数是 Tensor
时,右操作数会逻辑上广播以匹配 Tensor
的形状。
比较#
comparison ::= or_expr (comp_operator or_expr)*
comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in'
比较会产生布尔值(True
或 False
),或者如果其中一个操作数是 Tensor
,则产生布尔 Tensor
。只要比较不产生多于一个元素的布尔 Tensor
,就可以任意地链接比较。 a op1 b op2 c ...
等价于 a op1 b and b op2 c and ...
。
值比较#
运算符 <
、>
、==
、>=
、<=
和 !=
比较两个对象的数值。这两个对象通常需要是相同的类型,除非对象之间存在隐式类型转换。如果用户自定义类型定义了丰富的比较方法(例如 __lt__
),则可以比较它们。内置类型比较就像 Python 一样。
数字进行数学比较。
字符串进行字典序比较。
lists
、tuples
和dicts
只能与相同类型并具有相同形状的其他lists
、tuples
和dicts
进行比较,并使用相应元素的比较运算符进行比较。
成员测试操作#
运算符 in
和 not in
进行成员测试。x in s
求值为 True
,如果 x
是 s
的成员,否则为 False
。x not in s
等价于 not x in s
。此运算符支持 lists
、dicts
和 tuples
,如果用户自定义类型实现了 __contains__
方法,也可以使用它们。
同一性比较#
对于 int
、double
、bool
和 torch.device
以外的所有类型,运算符 is
和 is not
测试对象同一性;x is y
当且仅当 x
和 y
是同一个对象时为 True
。对于所有其他类型,is
等价于使用 ==
进行比较。x is not y
返回 x is y
的反值。
布尔运算#
or_test ::= and_test | or_test 'or' and_test
and_test ::= not_test | and_test 'and' not_test
not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test
用户自定义对象可以通过实现 __bool__
方法来定制其转换为 bool
。运算符 not
在其操作数为 false 时返回 True
,否则返回 False
。表达式 x
and y
首先计算 x
;如果 x
为 False
,则返回其值(False
);否则,计算 y
并返回其值(False
或 True
)。表达式 x
or y
首先计算 x
;如果 x
为 True
,则返回其值(True
);否则,计算 y
并返回其值(False
或 True
)。
条件表达式#
conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression]
expression ::= conditional_expression
表达式 x if c else y
首先计算条件 c
而不是 x。如果 c
为 True
,则计算 x
并返回其值;否则,计算 y
并返回其值。与 if 语句一样,x
和 y
必须计算为相同类型的值。
表达式列表#
expression_list ::= expression (',' expression)* [',']
starred_item ::= '*' primary
星号表达式只能出现在赋值语句的左侧,例如 a, *b, c = ...
。
简单语句#
以下部分描述了 TorchScript 支持的简单语句的语法。它模仿了 Python 语言参考的简单语句章节。
表达式语句#
expression_stmt ::= starred_expression
starred_expression ::= expression | (starred_item ",")* [starred_item]
starred_item ::= assignment_expression | "*" or_expr
赋值语句#
assignment_stmt ::= (target_list "=")+ (starred_expression)
target_list ::= target ("," target)* [","]
target ::= identifier
| "(" [target_list] ")"
| "[" [target_list] "]"
| attributeref
| subscription
| slicing
| "*" target
增强赋值语句#
augmented_assignment_stmt ::= augtarget augop (expression_list)
augtarget ::= identifier | attributeref | subscription
augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" |
"**="| ">>=" | "<<=" | "&=" | "^=" | "|="
注解赋值语句#
annotated_assignment_stmt ::= augtarget ":" expression
["=" (starred_expression)]
raise
语句#
raise_stmt ::= "raise" [expression ["from" expression]]
TorchScript 中的 Raise 语句不支持 try\except\finally
。
assert
语句#
assert_stmt ::= "assert" expression ["," expression]
TorchScript 中的 Assert 语句不支持 try\except\finally
。
return
语句#
return_stmt ::= "return" [expression_list]
TorchScript 中的 Return 语句不支持 try\except\finally
。
del
语句#
del_stmt ::= "del" target_list
pass
语句#
pass_stmt ::= "pass"
print
语句#
print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")"
break
语句#
break_stmt ::= "break"
continue
语句:#
continue_stmt ::= "continue"
复合语句#
以下部分描述了 TorchScript 支持的复合语句的语法。该部分还强调了 Torchscript 与常规 Python 语句的区别。它模仿了 Python 语言参考的复合语句章节。
if
语句#
Torchscript 同时支持基本的 if/else
和三元 if/else
。
基本 if/else
语句#
if_stmt ::= "if" assignment_expression ":" suite
("elif" assignment_expression ":" suite)
["else" ":" suite]
elif
语句可以重复任意次数,但它必须在 else
语句之前。
三元 if/else
语句#
if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list]
示例 1
一个具有 1 个维度的 tensor
会被提升为 bool
。
import torch
@torch.jit.script
def fn(x: torch.Tensor):
if x: # The tensor gets promoted to bool
return True
return False
print(fn(torch.rand(1)))
上面的示例产生以下输出
True
示例 2
多维度的 tensor
不会被提升为 bool
。
import torch
# Multi dimensional Tensors error out.
@torch.jit.script
def fn():
if torch.rand(2):
print("Tensor is available")
if torch.rand(4,5,6):
print("Tensor is available")
print(fn())
运行上述代码会产生以下 RuntimeError
。
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
@torch.jit.script
def fn():
if torch.rand(2):
~~~~~~~~~~~~ <--- HERE
print("Tensor is available")
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
如果条件变量被注解为 final
,则会根据条件变量的求值结果来求值 true 或 false 分支。
示例 3
在此示例中,由于 a
被注解为 final
且被设置为 True
,因此只对 True 分支进行了求值。
import torch
a : torch.jit.final[Bool] = True
if a:
return torch.empty(2,3)
else:
return []
while
语句#
while_stmt ::= "while" assignment_expression ":" suite
Torchscript 不支持 while...else
语句。这会导致 RuntimeError
。
for-in
语句#
for_stmt ::= "for" target_list "in" expression_list ":" suite
["else" ":" suite]
Torchscript 不支持 for...else
语句。这会导致 RuntimeError
。
示例 1
元组上的 for 循环:这些循环会展开循环,为元组的每个成员生成一个循环体。循环体必须正确地对每个成员进行类型检查。
import torch
from typing import Tuple
@torch.jit.script
def fn():
tup = (3, torch.ones(4))
for x in tup:
print(x)
fn()
上面的示例产生以下输出
3
1
1
1
1
[ CPUFloatType{4} ]
示例 2
列表上的 for 循环:对 nn.ModuleList
的 for 循环将在编译时展开循环体,并为模块列表的每个成员生成相应的循环体。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
model = torch.jit.script(MyModule())
with
语句#
with
语句用于将一个代码块的执行与由上下文管理器定义的方法包装起来。
with_stmt ::= "with" with_item ("," with_item) ":" suite
with_item ::= expression ["as" target]
如果一个目标包含在
with
语句中,上下文管理器的__enter__()
的返回值将赋给它。与 Python 不同,如果一个异常导致了 suite 的退出,其类型、值和 traceback 不会作为参数传递给__exit__()
。会提供三个None
参数。try
、except
和finally
语句在with
块内不支持。在
with
块内引发的异常不能被抑制。
tuple
语句#
tuple_stmt ::= tuple([iterables])
TorchScript 中的可迭代类型包括
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。您不能通过此内置函数将 List 转换为 Tuple。
将所有输出解包到元组由以下部分覆盖
abc = func() # Function that returns a tuple
a,b = func()
getattr
语句#
getattr_stmt ::= getattr(object, name[, default])
属性名称必须是文字字符串。
不支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
hasattr
语句#
hasattr_stmt ::= hasattr(object, name)
属性名称必须是文字字符串。
不支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
zip
语句#
zip_stmt ::= zip(iterable1, iterable2)
参数必须是可迭代对象。
支持相同外部容器类型但长度不同的两个可迭代对象。
示例 1
可迭代对象必须是相同的容器类型。
a = [1, 2] # List
b = [2, 3, 4] # List
zip(a, b) # works
示例 2
此示例失败是因为可迭代对象的容器类型不同。
a = (1, 2) # Tuple
b = [2, 3, 4] # List
zip(a, b) # Runtime error
运行上述代码会产生以下 RuntimeError
。
RuntimeError: Can not iterate over a module list or
tuple with a value that does not have a statically determinable length.
示例 3
支持相同容器类型但数据类型不同的两个可迭代对象。
a = [1.3, 2.4]
b = [2, 3, 4]
zip(a, b) # Works
TorchScript 中的可迭代类型包括 Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
enumerate
语句#
enumerate_stmt ::= enumerate([iterable])
参数必须是可迭代对象。
TorchScript 中的可迭代类型包括
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。
Python 值#
解析规则#
当给定一个 Python 值时,TorchScript 会尝试通过以下五种不同的方式来解析它。
- 可编译的 Python 实现
当一个 Python 值由 TorchScript 可以编译的 Python 实现支持时,TorchScript 会编译并使用底层的 Python 实现。
示例:
torch.jit.Attribute
- Op Python 包装器
当一个 Python 值是原生 PyTorch Op 的包装器时,TorchScript 会发出相应的运算符。
示例:
torch.jit._logging.add_stat_value
- Python 对象同一性匹配
对于 TorchScript 支持的一组有限的
torch.*
API 调用(以 Python 值形式),TorchScript 会尝试将 Python 值与集合中的每个项进行匹配。匹配成功时,TorchScript 会生成一个相应的
SugaredValue
实例,其中包含这些值的降低逻辑。示例:
torch.jit.isinstance()
- 名称匹配
对于 Python 内置函数和常量,TorchScript 会按名称识别它们,并创建一个相应的
SugaredValue
实例来实施其功能。示例:
all()
- 值快照
对于来自未知模块的 Python 值,TorchScript 会尝试获取该值的快照,并将其转换为被编译的函数或方法的图中的常量。
示例:
math.pi
Python 内置函数支持#
内置函数 |
支持级别 |
注意事项 |
---|---|---|
|
部分 |
仅支持 |
|
完全 |
|
|
完全 |
|
|
无 |
|
|
部分 |
仅支持 |
|
部分 |
仅支持 |
|
无 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
部分 |
仅支持 ASCII 字符集。 |
|
完全 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
完全 |
|
|
无 |
|
|
完全 |
|
|
完全 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
部分 |
不支持 |
|
部分 |
不支持手动指定索引。| 不支持格式类型修饰符。 |
|
无 |
|
|
部分 |
属性名称必须是字符串字面量。 |
|
无 |
|
|
部分 |
属性名称必须是字符串字面量。 |
|
完全 |
|
|
部分 |
仅支持 |
|
完全 |
仅支持 |
|
无 |
|
|
部分 |
不支持 |
|
完全 |
|
|
无 |
|
|
无 |
|
|
完全 |
|
|
完全 |
|
|
部分 |
仅支持 ASCII 字符集。 |
|
完全 |
|
|
部分 |
不支持 |
|
无 |
|
|
完全 |
|
|
无 |
|
|
无 |
|
|
部分 |
不支持 |
|
无 |
|
|
无 |
|
|
完全 |
|
|
部分 |
不支持 |
|
完全 |
|
|
部分 |
不支持 |
|
完全 |
|
|
部分 |
它只能在 |
|
无 |
|
|
无 |
|
|
完全 |
|
|
无 |
torch.* API#
远程过程调用#
TorchScript 支持 RPC API 的一个子集,该子集支持在指定的远程工作节点上运行函数,而不是在本地运行。
具体来说,以下 API 完全支持:
torch.distributed.rpc.rpc_sync()
rpc_sync()
发起一个阻塞 RPC 调用,以在远程工作节点上运行一个函数。RPC 消息在 Python 代码执行的并行进行中发送和接收。有关其用法和示例的更多详细信息,请参阅
rpc_sync()
。
torch.distributed.rpc.rpc_async()
rpc_async()
发起一个非阻塞 RPC 调用,以在远程工作节点上运行一个函数。RPC 消息在 Python 代码执行的并行进行中发送和接收。有关其用法和示例的更多详细信息,请参阅
rpc_async()
。
torch.distributed.rpc.remote()
remote.()
在工作节点上执行远程调用,并返回一个远程引用RRef
作为返回值。有关其用法和示例的更多详细信息,请参阅
remote()
。
异步执行#
TorchScript 使您能够创建异步计算任务,以更好地利用计算资源。这是通过支持一组仅在 TorchScript 中可用的 API 来实现的。
类型注解#
TorchScript 是静态类型的。它提供并支持一组实用程序来帮助注解变量和属性。
torch.jit.annotate()
为 TorchScript 提供类型提示,其中 Python 3 风格的类型提示效果不佳。
一个常见的例子是为
[]
等表达式注解类型。[]
默认被视为List[torch.Tensor]
。当需要不同类型时,可以使用此代码向 TorchScript 提供提示:torch.jit.annotate(List[int], [])
。更多详细信息可以在
annotate()
中找到。
torch.jit.Attribute
常见的用例包括为
torch.nn.Module
属性提供类型提示。由于它们的__init__
方法未被 TorchScript 解析,因此在模块的__init__
方法中应使用torch.jit.Attribute
而不是torch.jit.annotate
。更多详细信息可以在
Attribute()
中找到。
torch.jit.Final
Python 的
typing.Final
的别名。torch.jit.Final
仅为保持向后兼容性而保留。
元编程#
TorchScript 提供了一套实用程序来促进元编程。
torch.jit.is_scripting()
返回一个布尔值,指示当前程序是否由
torch.jit.script
编译。当在
assert
或if
语句中使用时,torch.jit.is_scripting()
求值为False
的作用域或分支不会被编译。其值可以在编译时静态求值,因此常用于
if
语句以阻止 TorchScript 编译其中一个分支。更多详细信息和示例可以在
is_scripting()
中找到。
torch.jit.is_tracing()
返回一个布尔值,指示当前程序是否由
torch.jit.trace
/torch.jit.trace_module
跟踪。更多详细信息可以在
is_tracing()
中找到。
@torch.jit.ignore
此装饰器指示编译器忽略一个函数或方法,并将其保留为 Python 函数。
这允许您在模型中保留尚未与 TorchScript 兼容的代码。
如果一个由
@torch.jit.ignore
装饰的函数从 TorchScript 调用,则被忽略的函数会将调用分派到 Python 解释器。带有被忽略函数的模型无法导出。
更多详细信息和示例可以在
ignore()
中找到。
@torch.jit.unused
此装饰器指示编译器忽略一个函数或方法,并将其替换为引发异常。
这允许您在模型中保留尚未与 TorchScript 兼容的代码,并且仍然可以导出模型。
如果一个由
@torch.jit.unused
装饰的函数从 TorchScript 调用,则会引发运行时错误。更多详细信息和示例可以在
unused()
中找到。
类型细化#
torch.jit.isinstance()
返回一个布尔值,指示一个变量是否是指定类型。
有关其用法和示例的更多详细信息,请参阅
isinstance()
。