评价此页

TorchScript Language Reference#

Created On: Dec 18, 2019 | Last Updated On: Jun 13, 2025

TorchScript is a statically typed subset of Python that can either be written directly (using the @torch.jit.script decorator) or generated automatically from Python code via tracing. When using tracing, code is automatically converted into this subset of Python by recording only the actual operators on tensors and simply executing and discarding the other surrounding Python code.

When writing TorchScript directly using @torch.jit.script decorator, the programmer must only use the subset of Python supported in TorchScript. This section documents what is supported in TorchScript as if it were a language reference for a stand alone language. Any features of Python not mentioned in this reference are not part of TorchScript. See Builtin Functions for a complete reference of available PyTorch tensor methods, modules, and functions.

As a subset of Python, any valid TorchScript function is also a valid Python function. This makes it possible to disable TorchScript and debug the function using standard Python tools like pdb. The reverse is not true: there are many valid Python programs that are not valid TorchScript programs. Instead, TorchScript focuses specifically on the features of Python that are needed to represent neural network models in PyTorch.

Types#

The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express neural net models. In particular, TorchScript supports

类型

描述

张量

A PyTorch tensor of any dtype, dimension, or backend

Tuple[T0, T1, ..., TN]

A tuple containing subtypes T0, T1, etc. (e.g. Tuple[Tensor, Tensor])

布尔值

A boolean value

int

A scalar integer

浮点数

A scalar floating point number

str

A string

List[T]

A list of which all members are type T

Optional[T]

A value which is either None or type T

Dict[K, V]

A dict with key type K and value type V. Only str, int, and float are allowed as key types.

T

A {ref}`TorchScript Class`

E

A {ref}`TorchScript Enum`

NamedTuple[T0, T1, ...]

A collections.namedtuple tuple type

Union[T0, T1, ...]

One of the subtypes T0, T1, etc.

Unlike Python, each variable in TorchScript function must have a single static type. This makes it easier to optimize TorchScript functions.

Example (a type mismatch)

import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...

Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE...

Unsupported Typing Constructs#

TorchScript does not support all features and types of the typing module. Some of these are more fundamental things that are unlikely to be added in the future while others may be added if there is enough user demand to make it a priority.

These types and features from the typing module are unavailable in TorchScript.

Item

描述

typing.Any

typing.Any is currently in development but not yet released

typing.NoReturn

Not implemented

typing.Sequence

Not implemented

typing.Callable

Not implemented

typing.Literal

Not implemented

typing.ClassVar

Not implemented

typing.Final

This is supported for module attributes class attribute annotations but not for functions

typing.AnyStr

TorchScript does not support bytes so this type is not used

typing.overload

typing.overload is currently in development but not yet released

Type aliases

Not implemented

Nominal vs structural subtyping

Nominal typing is in development, but structural typing is not

NewType

Unlikely to be implemented

Generics

Unlikely to be implemented

Any other functionality from the typing module not explicitly listed in this documentation is unsupported.

Default Types#

By default, all parameters to a TorchScript function are assumed to be Tensor. To specify that an argument to a TorchScript function is another type, it is possible to use MyPy-style type annotations using the types listed above.

import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

注意

It is also possible to annotate types with Python 3 type hints from the typing module.

import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

An empty list is assumed to be List[Tensor] and empty dicts Dict[str, Tensor]. To instantiate an empty list or dict of other types, use Python 3 type hints.

Example (type annotations for Python 3)

import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

Optional Type Refinement#

TorchScript will refine the type of a variable of type Optional[T] when a comparison to None is made inside the conditional of an if-statement or checked in an assert. The compiler can reason about multiple None checks that are combined with and, or, and not. Refinement will also occur for else blocks of if-statements that are not explicitly written.

The None check must be within the if-statement’s condition; assigning a None check to a variable and using it in the if-statement’s condition will not refine the types of variables in the check. Only local variables will be refined, an attribute like self.x will not and must assigned to a local variable to be refined.

Example (refining types on parameters and locals)

import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super().__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript Classes#

警告

TorchScript class support is experimental. Currently it is best suited for simple record-like types (think a NamedTuple with methods attached).

Python classes can be used in TorchScript if they are annotated with @torch.jit.script, similar to how you would declare a TorchScript function

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

This subset is restricted

  • All functions must be valid TorchScript functions (including __init__()).

  • Classes must be new-style classes, as we use __new__() to construct them with pybind11.

  • TorchScript classes are statically typed. Members can only be declared by assigning to self in the __init__() method.

    For example, assigning to self outside of the __init__() method

    @torch.jit.script
    class Foo:
      def assign_x(self):
        self.x = torch.rand(2, 3)
    

    Will result in

    RuntimeError:
    Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
    def assign_x(self):
      self.x = torch.rand(2, 3)
      ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
  • No expressions except method definitions are allowed in the body of the class.

  • No support for inheritance or any other polymorphism strategy, except for inheriting from object to specify a new-style class.

定义类之后,它可以在 TorchScript 和 Python 中像其他 TorchScript 类型一样进行互换使用。

# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

TorchScript 枚举#

Python 枚举可以在 TorchScript 中使用,无需任何额外的注解或代码。

from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

定义枚举后,它可以在 TorchScript 和 Python 中像其他 TorchScript 类型一样进行互换使用。枚举值的类型必须是 intfloatstr。所有值必须是同一类型;不支持枚举值的异构类型。

命名元组#

collections.namedtuple 生成的类型可以在 TorchScript 中使用。

import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

可迭代对象#

某些函数(例如 zipenumerate)只能操作可迭代类型。TorchScript 中的可迭代类型包括 Tensor、列表、元组、字典、字符串、torch.nn.ModuleListtorch.nn.ModuleDict

表达式#

支持以下 Python 表达式。

字面量#

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

列表构建#

空列表被假定为 List[Tensor] 类型。其他列表字面量的类型从其成员的类型推断。有关更多详细信息,请参见 [默认类型]。

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元组构建#

(3, 4)
(3,)

字典构建#

空字典被假定为 Dict[str, Tensor] 类型。其他字典字面量的类型从其成员的类型推断。有关更多详细信息,请参见 [默认类型]。

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

变量#

有关变量解析,请参见 [变量解析]。

my_variable_name

算术运算符#

a + b
a - b
a * b
a / b
a ^ b
a @ b

比较运算符#

a == b
a != b
a < b
a > b
a <= b
a >= b

逻辑运算符#

a and b
a or b
not b

下标和切片#

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函数调用#

内置函数 的调用

torch.rand(3, dtype=torch.int)

对其他脚本函数的调用

import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

方法调用#

对张量等内置类型方法的调用:x.mm(y)

在模块上,方法在调用前必须经过编译。TorchScript 编译器在编译其他方法时会递归地编译它看到的方法。默认情况下,编译从 forward 方法开始。由 forward 调用的任何方法,以及由这些方法调用的任何方法,依此类推,都将被编译。要从 forward 以外的方法开始编译,请使用 @torch.jit.export 装饰器(forward 会被隐式标记为 @torch.jit.export)。

直接调用子模块(例如 self.resnet(input))等同于调用其 forward 方法(例如 self.resnet.forward(input))。

import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表达式#

x if x > y else y

类型转换#

float(ten)
int(3.5)
bool(ten)
str(2)``

访问模块参数#

self.my_parameter
self.my_submodule.my_parameter

语句#

TorchScript 支持以下类型的语句:

简单赋值#

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

模式匹配赋值#

a, b = tuple_or_list
a, b, *c = a_tuple

多重赋值

a = b, c = tup

If 语句#

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除了布尔值、浮点数、整数和张量之外,它们也可以用在条件语句中,并且会被隐式转换为布尔值。

While 循环#

a = 0
while a < 4:
    print(a)
    a += 1

带 range 的 for 循环#

x = 0
for i in range(10):
    x *= i

遍历元组的 for 循环#

这些循环会展开,为元组的每个成员生成一个循环体。循环体必须正确地进行类型检查。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

遍历常量 nn.ModuleList 的 for 循环#

要在编译后的方法中使用 nn.ModuleList,必须通过将属性名称添加到该类型的 __constants__ 列表中来将其标记为常量。常量模块列表的 for 循环将在编译时展开循环体,针对每个常量模块列表的成员。

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

Break 和 Continue#

for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)

Return#

return a, b

变量解析#

TorchScript 支持 Python 变量解析(即作用域)规则的子集。局部变量的行为与 Python 中的相同,但有一个限制,即变量在函数的所有路径中必须具有相同的类型。如果在 if 语句的不同分支中变量具有不同类型,那么在 if 语句之后使用它将是错误。

同样,如果一个变量仅在函数的部分路径中定义,则不允许使用它。

示例

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...

非局部变量在函数定义时被解析为 Python 值。然后,这些值将使用 [Python 值的使用] 部分所述的规则转换为 TorchScript 值。

Python 值的使用#

为了方便编写 TorchScript,我们允许脚本代码引用周围作用域中的 Python 值。例如,任何对 torch 的引用,当函数声明时,TorchScript 编译器实际上会将其解析为 torch Python 模块。这些 Python 值不是 TorchScript 的一等公民。相反,它们在编译时被解语法糖为 TorchScript 支持的原生类型。这取决于编译时引用的 Python 值的动态类型。本节介绍访问 TorchScript 中 Python 值时使用的规则。

函数#

TorchScript 可以调用 Python 函数。这项功能在逐步将模型转换为 TorchScript 时非常有用。模型可以逐个函数地迁移到 TorchScript,并保留对 Python 函数的调用。这样,您可以一边转换一边逐步检查模型的正确性。

torch.jit.is_scripting()[source]#

当处于编译状态时返回 True,否则返回 False。这对于使用 @unused 装饰器来保留模型中尚未兼容 TorchScript 的代码特别有用。.. testcode

import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
    if torch.jit.is_scripting():
        return torch.linear(x)
    else:
        return unsupported_linear_op(x)
返回类型

布尔值

torch.jit.is_tracing()[source]#

返回一个布尔值。

在跟踪(如果函数在具有 torch.jit.trace 的代码跟踪期间被调用)时返回 True,否则返回 False

Python 模块上的属性查找#

TorchScript 可以查找模块上的属性。诸如 torch.add 之类的 内置函数 就是这样访问的。这使得 TorchScript 能够调用其他模块中定义的函数。

Python 定义的常量#

TorchScript 还提供了一种使用 Python 中定义的常量的方法。这些常量可用于将超参数硬编码到函数中,或定义通用常量。有两种方法可以指定将 Python 值视为常量。

  1. 作为模块属性查找的值被假定为常量。

import math
import torch

@torch.jit.script
def fn():
    return math.pi
  1. 可以通过使用 Final[T] 注解来将 ScriptModule 的属性标记为常量。

import torch
import torch.nn as nn

class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]

    def __init__(self):
        super().__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

支持的常量 Python 类型有:

  • int

  • 浮点数

  • 布尔值

  • torch.device

  • torch.layout

  • torch.dtype

  • 包含支持类型的元组

  • torch.nn.ModuleList,可以在 TorchScript for 循环中使用

模块属性#

torch.nn.Parameter 包装器和 register_buffer 可用于将张量分配给模块。分配给已编译模块的其他值,如果它们的类型可以被推断,也会被添加到已编译模块中。TorchScript 中可用的所有 [类型] 都可以用作模块属性。张量属性在语义上与缓冲区相同。空列表、字典和 None 值的类型无法推断,必须通过 PEP 526 风格的类注解来指定。如果类型无法推断且未明确注解,则不会将其添加为结果 ScriptModule 的属性。

示例

from typing import List, Dict

class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]

    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super().__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))