TorchScript Language Reference#
Created On: Dec 18, 2019 | Last Updated On: Jun 13, 2025
TorchScript is a statically typed subset of Python that can either be written directly (using the @torch.jit.script
decorator) or generated automatically from Python code via tracing. When using tracing, code is automatically converted into this subset of Python by recording only the actual operators on tensors and simply executing and discarding the other surrounding Python code.
When writing TorchScript directly using @torch.jit.script
decorator, the programmer must only use the subset of Python supported in TorchScript. This section documents what is supported in TorchScript as if it were a language reference for a stand alone language. Any features of Python not mentioned in this reference are not part of TorchScript. See Builtin Functions
for a complete reference of available PyTorch tensor methods, modules, and functions.
As a subset of Python, any valid TorchScript function is also a valid Python function. This makes it possible to disable TorchScript
and debug the function using standard Python tools like pdb
. The reverse is not true: there are many valid Python programs that are not valid TorchScript programs. Instead, TorchScript focuses specifically on the features of Python that are needed to represent neural network models in PyTorch.
Types#
The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express neural net models. In particular, TorchScript supports
类型 |
描述 |
---|---|
|
A PyTorch tensor of any dtype, dimension, or backend |
|
A tuple containing subtypes |
|
A boolean value |
|
A scalar integer |
|
A scalar floating point number |
|
A string |
|
A list of which all members are type |
|
A value which is either None or type |
|
A dict with key type |
|
A {ref}`TorchScript Class` |
|
A {ref}`TorchScript Enum` |
|
A |
|
One of the subtypes |
Unlike Python, each variable in TorchScript function must have a single static type. This makes it easier to optimize TorchScript functions.
Example (a type mismatch)
import torch
@torch.jit.script
def an_error(x):
if x:
r = torch.rand(1)
else:
r = 4
return r
Traceback (most recent call last):
...
RuntimeError: ...
Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
if x:
~~~~~
r = torch.rand(1)
~~~~~~~~~~~~~~~~~
else:
~~~~~
r = 4
~~~~~ <--- HERE
return r
and was used here:
else:
r = 4
return r
~ <--- HERE...
Unsupported Typing Constructs#
TorchScript does not support all features and types of the typing
module. Some of these are more fundamental things that are unlikely to be added in the future while others may be added if there is enough user demand to make it a priority.
These types and features from the typing
module are unavailable in TorchScript.
Item |
描述 |
---|---|
|
|
Not implemented |
|
Not implemented |
|
Not implemented |
|
Not implemented |
|
Not implemented |
|
This is supported for module attributes class attribute annotations but not for functions |
|
TorchScript does not support |
|
|
|
Type aliases |
Not implemented |
Nominal vs structural subtyping |
Nominal typing is in development, but structural typing is not |
NewType |
Unlikely to be implemented |
Generics |
Unlikely to be implemented |
Any other functionality from the typing
module not explicitly listed in this documentation is unsupported.
Default Types#
By default, all parameters to a TorchScript function are assumed to be Tensor. To specify that an argument to a TorchScript function is another type, it is possible to use MyPy-style type annotations using the types listed above.
import torch
@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
注意
It is also possible to annotate types with Python 3 type hints from the typing
module.
import torch
from typing import Tuple
@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
An empty list is assumed to be List[Tensor]
and empty dicts Dict[str, Tensor]
. To instantiate an empty list or dict of other types, use Python 3 type hints
.
Example (type annotations for Python 3)
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class EmptyDataStructures(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
# This annotates the list to be a `List[Tuple[int, float]]`
my_list: List[Tuple[int, float]] = []
for i in range(10):
my_list.append((i, x.item()))
my_dict: Dict[str, int] = {}
return my_list, my_dict
x = torch.jit.script(EmptyDataStructures())
Optional Type Refinement#
TorchScript will refine the type of a variable of type Optional[T]
when a comparison to None
is made inside the conditional of an if-statement or checked in an assert
. The compiler can reason about multiple None
checks that are combined with and
, or
, and not
. Refinement will also occur for else blocks of if-statements that are not explicitly written.
The None
check must be within the if-statement’s condition; assigning a None
check to a variable and using it in the if-statement’s condition will not refine the types of variables in the check. Only local variables will be refined, an attribute like self.x
will not and must assigned to a local variable to be refined.
Example (refining types on parameters and locals)
import torch
import torch.nn as nn
from typing import Optional
class M(nn.Module):
z: Optional[int]
def __init__(self, z):
super().__init__()
# If `z` is None, its type cannot be inferred, so it must
# be specified (above)
self.z = z
def forward(self, x, y, z):
# type: (Optional[int], Optional[int], Optional[int]) -> int
if x is None:
x = 1
x = x + 1
# Refinement for an attribute by assigning it to a local
z = self.z
if y is not None and z is not None:
x = y + z
# Refinement via an `assert`
assert z is not None
x += z
return x
module = torch.jit.script(M(2))
module = torch.jit.script(M(None))
TorchScript Classes#
警告
TorchScript class support is experimental. Currently it is best suited for simple record-like types (think a NamedTuple
with methods attached).
Python classes can be used in TorchScript if they are annotated with @torch.jit.script
, similar to how you would declare a TorchScript function
@torch.jit.script
class Foo:
def __init__(self, x, y):
self.x = x
def aug_add_x(self, inc):
self.x += inc
This subset is restricted
All functions must be valid TorchScript functions (including
__init__()
).Classes must be new-style classes, as we use
__new__()
to construct them with pybind11.TorchScript classes are statically typed. Members can only be declared by assigning to self in the
__init__()
method.For example, assigning to
self
outside of the__init__()
method@torch.jit.script class Foo: def assign_x(self): self.x = torch.rand(2, 3)
Will result in
RuntimeError: Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: def assign_x(self): self.x = torch.rand(2, 3) ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
No expressions except method definitions are allowed in the body of the class.
No support for inheritance or any other polymorphism strategy, except for inheriting from
object
to specify a new-style class.
定义类之后,它可以在 TorchScript 和 Python 中像其他 TorchScript 类型一样进行互换使用。
# Declare a TorchScript class
@torch.jit.script
class Pair:
def __init__(self, first, second):
self.first = first
self.second = second
@torch.jit.script
def sum_pair(p):
# type: (Pair) -> Tensor
return p.first + p.second
p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))
TorchScript 枚举#
Python 枚举可以在 TorchScript 中使用,无需任何额外的注解或代码。
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
定义枚举后,它可以在 TorchScript 和 Python 中像其他 TorchScript 类型一样进行互换使用。枚举值的类型必须是 int
、float
或 str
。所有值必须是同一类型;不支持枚举值的异构类型。
命名元组#
由 collections.namedtuple
生成的类型可以在 TorchScript 中使用。
import torch
import collections
Point = collections.namedtuple('Point', ['x', 'y'])
@torch.jit.script
def total(point):
# type: (Point) -> Tensor
return point.x + point.y
p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))
可迭代对象#
某些函数(例如 zip
和 enumerate
)只能操作可迭代类型。TorchScript 中的可迭代类型包括 Tensor
、列表、元组、字典、字符串、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
表达式#
支持以下 Python 表达式。
字面量#
True
False
None
'string literals'
"string literals"
3 # interpreted as int
3.4 # interpreted as a float
列表构建#
空列表被假定为 List[Tensor]
类型。其他列表字面量的类型从其成员的类型推断。有关更多详细信息,请参见 [默认类型]。
[3, 4]
[]
[torch.rand(3), torch.rand(4)]
元组构建#
(3, 4)
(3,)
字典构建#
空字典被假定为 Dict[str, Tensor]
类型。其他字典字面量的类型从其成员的类型推断。有关更多详细信息,请参见 [默认类型]。
{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}
算术运算符#
a + b
a - b
a * b
a / b
a ^ b
a @ b
比较运算符#
a == b
a != b
a < b
a > b
a <= b
a >= b
逻辑运算符#
a and b
a or b
not b
下标和切片#
t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]
函数调用#
对 内置函数
的调用
torch.rand(3, dtype=torch.int)
对其他脚本函数的调用
import torch
@torch.jit.script
def foo(x):
return x + 1
@torch.jit.script
def bar(x):
return foo(x)
方法调用#
对张量等内置类型方法的调用:x.mm(y)
在模块上,方法在调用前必须经过编译。TorchScript 编译器在编译其他方法时会递归地编译它看到的方法。默认情况下,编译从 forward
方法开始。由 forward
调用的任何方法,以及由这些方法调用的任何方法,依此类推,都将被编译。要从 forward
以外的方法开始编译,请使用 @torch.jit.export
装饰器(forward
会被隐式标记为 @torch.jit.export
)。
直接调用子模块(例如 self.resnet(input)
)等同于调用其 forward
方法(例如 self.resnet.forward(input)
)。
import torch
import torch.nn as nn
import torchvision
class MyModule(nn.Module):
def __init__(self):
super().__init__()
means = torch.tensor([103.939, 116.779, 123.68])
self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
resnet = torchvision.models.resnet18()
self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))
def helper(self, input):
return self.resnet(input - self.means)
def forward(self, input):
return self.helper(input)
# Since nothing in the model calls `top_level_method`, the compiler
# must be explicitly told to compile this method
@torch.jit.export
def top_level_method(self, input):
return self.other_helper(input)
def other_helper(self, input):
return input + 10
# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())
三元表达式#
x if x > y else y
类型转换#
float(ten)
int(3.5)
bool(ten)
str(2)``
访问模块参数#
self.my_parameter
self.my_submodule.my_parameter
语句#
TorchScript 支持以下类型的语句:
简单赋值#
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b
打印语句#
print("the result of an add:", a + b)
If 语句#
if a < 4:
r = -a
elif a < 3:
r = a + a
else:
r = 3 * a
除了布尔值、浮点数、整数和张量之外,它们也可以用在条件语句中,并且会被隐式转换为布尔值。
While 循环#
a = 0
while a < 4:
print(a)
a += 1
带 range 的 for 循环#
x = 0
for i in range(10):
x *= i
遍历元组的 for 循环#
这些循环会展开,为元组的每个成员生成一个循环体。循环体必须正确地进行类型检查。
tup = (3, torch.rand(4))
for x in tup:
print(x)
遍历常量 nn.ModuleList 的 for 循环#
要在编译后的方法中使用 nn.ModuleList
,必须通过将属性名称添加到该类型的 __constants__
列表中来将其标记为常量。常量模块列表的 for 循环将在编译时展开循环体,针对每个常量模块列表的成员。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
__constants__ = ['mods']
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
m = torch.jit.script(MyModule())
Break 和 Continue#
for i in range(5):
if i == 1:
continue
if i == 3:
break
print(i)
Return#
return a, b
变量解析#
TorchScript 支持 Python 变量解析(即作用域)规则的子集。局部变量的行为与 Python 中的相同,但有一个限制,即变量在函数的所有路径中必须具有相同的类型。如果在 if 语句的不同分支中变量具有不同类型,那么在 if 语句之后使用它将是错误。
同样,如果一个变量仅在函数的部分路径中定义,则不允许使用它。
示例
@torch.jit.script
def foo(x):
if x < 0:
y = 4
print(y)
Traceback (most recent call last):
...
RuntimeError: ...
y is not defined in the false branch...
@torch.jit.script...
def foo(x):
if x < 0:
~~~~~~~~~
y = 4
~~~~~ <--- HERE
print(y)
and was used here:
if x < 0:
y = 4
print(y)
~ <--- HERE...
非局部变量在函数定义时被解析为 Python 值。然后,这些值将使用 [Python 值的使用] 部分所述的规则转换为 TorchScript 值。
Python 值的使用#
为了方便编写 TorchScript,我们允许脚本代码引用周围作用域中的 Python 值。例如,任何对 torch
的引用,当函数声明时,TorchScript 编译器实际上会将其解析为 torch
Python 模块。这些 Python 值不是 TorchScript 的一等公民。相反,它们在编译时被解语法糖为 TorchScript 支持的原生类型。这取决于编译时引用的 Python 值的动态类型。本节介绍访问 TorchScript 中 Python 值时使用的规则。
函数#
TorchScript 可以调用 Python 函数。这项功能在逐步将模型转换为 TorchScript 时非常有用。模型可以逐个函数地迁移到 TorchScript,并保留对 Python 函数的调用。这样,您可以一边转换一边逐步检查模型的正确性。
- torch.jit.is_scripting()[source]#
当处于编译状态时返回 True,否则返回 False。这对于使用 @unused 装饰器来保留模型中尚未兼容 TorchScript 的代码特别有用。.. testcode
import torch @torch.jit.unused def unsupported_linear_op(x): return x def linear(x): if torch.jit.is_scripting(): return torch.linear(x) else: return unsupported_linear_op(x)
- 返回类型
Python 模块上的属性查找#
TorchScript 可以查找模块上的属性。诸如 torch.add
之类的 内置函数
就是这样访问的。这使得 TorchScript 能够调用其他模块中定义的函数。
Python 定义的常量#
TorchScript 还提供了一种使用 Python 中定义的常量的方法。这些常量可用于将超参数硬编码到函数中,或定义通用常量。有两种方法可以指定将 Python 值视为常量。
作为模块属性查找的值被假定为常量。
import math
import torch
@torch.jit.script
def fn():
return math.pi
可以通过使用
Final[T]
注解来将 ScriptModule 的属性标记为常量。
import torch
import torch.nn as nn
class Foo(nn.Module):
# `Final` from the `typing_extensions` module can also be used
a : torch.jit.Final[int]
def __init__(self):
super().__init__()
self.a = 1 + 4
def forward(self, input):
return self.a + input
f = torch.jit.script(Foo())
支持的常量 Python 类型有:
int
浮点数
布尔值
torch.device
torch.layout
torch.dtype
包含支持类型的元组
torch.nn.ModuleList
,可以在 TorchScript for 循环中使用
模块属性#
torch.nn.Parameter
包装器和 register_buffer
可用于将张量分配给模块。分配给已编译模块的其他值,如果它们的类型可以被推断,也会被添加到已编译模块中。TorchScript 中可用的所有 [类型] 都可以用作模块属性。张量属性在语义上与缓冲区相同。空列表、字典和 None
值的类型无法推断,必须通过 PEP 526 风格的类注解来指定。如果类型无法推断且未明确注解,则不会将其添加为结果 ScriptModule
的属性。
示例
from typing import List, Dict
class Foo(nn.Module):
# `words` is initialized as an empty list, so its type must be specified
words: List[str]
# The type could potentially be inferred if `a_dict` (below) was not
# empty, but this annotation ensures `some_dict` will be made into the
# proper type
some_dict: Dict[str, int]
def __init__(self, a_dict):
super().__init__()
self.words = []
self.some_dict = a_dict
# `int`s can be inferred
self.my_int = 10
def forward(self, input):
# type: (str) -> int
self.words.append(input)
return self.some_dict[input] + self.my_int
f = torch.jit.script(Foo({'hi': 2}))