注意
转到末尾 下载完整的示例代码。
通过示例学习混合前端语法#
创建于:2018 年 7 月 1 日 | 最后更新:2018 年 12 月 7 日 | 最后验证:未验证
作者: Nathan Inkawhich
本文旨在通过一个非代码密集型示例来突出混合前端的语法。混合前端是 Pytorch 1.0 的新亮点功能之一,它为开发人员提供了将模型从 Eager 模式 转换为 Graph 模式 的途径。PyTorch 用户非常熟悉 Eager 模式,因为它提供了我们作为研究人员都喜欢的易用性和灵活性。Caffe2 用户更熟悉 Graph 模式,它具有速度、优化机会和 C++ 运行时环境中的功能等优点。混合前端通过允许研究人员在 Eager 模式(即 PyTorch)中开发和完善模型,然后当速度和资源消耗变得至关重要时,逐步将经过验证的模型转换为 Graph 模式以用于生产,从而弥合了这两种模式之间的差距。
混合前端信息#
将模型转换为图模式的过程如下。首先,开发人员在 Eager 模式下构建、训练和测试模型。然后,他们使用即时 (JIT) 编译器逐步 跟踪 (trace) 和 脚本化 (script) 模型的每个函数/模块,每一步都验证输出是否正确。最后,当顶层模型的每个组件都被跟踪和脚本化后,模型本身就被跟踪了。此时模型已转换为图模式,并具有完整的无 Python 表示。通过这种表示,模型运行时可以利用高性能的 Caffe2 运算符和基于图的优化。
在继续之前,理解跟踪和脚本化的概念以及它们为何分离非常重要。跟踪 和 脚本化 的目标是相同的,那就是为给定函数中发生的操作创建图表示。差异来自 Eager 模式的灵活性,它允许模型架构中存在 数据依赖控制流。当函数没有数据依赖控制流时,可以使用 torch.jit.trace
进行 *跟踪*。但是,当函数 *有* 数据依赖控制流时,必须使用 torch.jit.script
进行 *脚本化*。我们将把混合前端的内部工作细节留待其他文档介绍,但下面的代码示例将展示如何跟踪和脚本化不同的纯 Python 函数和 Torch 模块的语法。希望您会发现使用混合前端是非侵入性的,因为它主要涉及向现有函数和类定义添加装饰器。
激励性示例#
在此示例中,我们将实现一个奇怪的数学函数,该函数在逻辑上可以分解为不包含和包含数据依赖控制流的四个部分。这里的目的是展示一个非代码密集型的示例,其中突出显示了 JIT 的使用。此示例是一个有用模型的代表,其实现已分为各种纯 Python 函数和模块。
我们希望实现的函数 \(Y(x)\),对于 \(x \epsilon \mathbb{N}\) 定义为
如前所述,计算分为四个部分。第一部分是 \(|2x|\) 的简单张量计算,可以被跟踪。第二部分是迭代乘积计算,它表示要被脚本化的数据相关控制流(循环迭代次数取决于运行时输入)。第三部分是可跟踪的 \(\lfloor \sqrt{a/5} \rfloor\) 计算。最后,第四部分根据 \(z(x)\) 的值处理输出情况,并且由于数据依赖性而必须被脚本化。现在,让我们看看这在代码中是什么样子。
第一部分 - 跟踪纯 Python 函数#
我们可以将第一部分实现为下面的纯 Python 函数。请注意,要跟踪此函数,我们调用 torch.jit.trace
并传入要跟踪的函数。由于跟踪需要预期运行时类型和形状的虚拟输入,我们还包括 torch.rand
来生成一个单值 torch 张量。
import torch
def fn(x):
return torch.abs(2*x)
# This is how you define a traced function
# Pass in both the function to be traced and an example input to ``torch.jit.trace``
traced_fn = torch.jit.trace(fn, torch.rand(()))
第二部分 - 脚本化纯 Python 函数#
我们还可以将第二部分实现为纯 Python 函数,其中我们迭代地计算乘积。由于迭代次数取决于输入的值,因此我们有数据相关的控制流,因此函数必须被脚本化。我们可以使用 @torch.jit.script
装饰器简单地脚本化 Python 函数。
# This is how you define a script function
# Apply this decorator directly to the function
@torch.jit.script
def script_fn(x):
z = torch.ones([1], dtype=torch.int64)
for i in range(int(x)):
z = z * (i + 1)
return z
第三部分 - 跟踪 nn.Module#
接下来,我们将在 torch.nn.Module
的 forward 函数中实现计算的第三部分。此模块可以被跟踪,但我们不会在这里添加装饰器,而是在模块构建时处理跟踪。因此,类定义完全没有改变。
# This is a normal module that can be traced.
class TracedModule(torch.nn.Module):
def forward(self, x):
x = x.type(torch.float32)
return torch.floor(torch.sqrt(x) / 5.)
第四部分 - 脚本化 nn.Module#
在计算的最后一部分,我们有一个必须被脚本化的 torch.nn.Module
。为了适应这一点,我们继承自 torch.jit.ScriptModule
,并向 forward 函数添加 @torch.jit.script_method
装饰器。
# This is how you define a scripted module.
# The module should inherit from ScriptModule and the forward should have the
# script_method decorator applied to it.
class ScriptModule(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
r = -x
if int(torch.fmod(x, 2.0)) == 0.0:
r = x / 2.0
return r
顶层模块#
现在我们将通过一个名为 Net
的顶层模块将计算的各个部分组合起来。在构造函数中,我们将实例化 TracedModule
和 ScriptModule
作为属性。这是必须的,因为我们最终想要跟踪/脚本化顶层模块,并且将跟踪/脚本化模块作为属性允许 Net 继承所需的子模块参数。请注意,这就是我们通过调用 torch.jit.trace()
并提供必要的虚拟输入来实际跟踪 TracedModule
的地方。还要注意,ScriptModule
是正常构造的,因为我们在类定义中处理了脚本化。
在这里我们还可以打印为计算的每个单独部分创建的图。打印出的图允许我们查看 JIT 最终如何将函数解释为图计算。
最后,我们为 Net 模块定义 forward
函数,其中我们将输入数据 x
运行通过计算的四个部分。这里没有奇怪的语法,我们像预期一样调用跟踪和脚本化模块和函数。
# This is a demonstration net that calls all of the different types of
# methods and functions
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
# Modules must be attributes on the Module because if you want to trace
# or script this Module, we must be able to inherit the submodules'
# params.
self.traced_module = torch.jit.trace(TracedModule(), torch.rand(()))
self.script_module = ScriptModule()
print('traced_fn graph', traced_fn.graph)
print('script_fn graph', script_fn.graph)
print('TracedModule graph', self.traced_module.__getattr__('forward').graph)
print('ScriptModule graph', self.script_module.__getattr__('forward').graph)
def forward(self, x):
# Call a traced function
x = traced_fn(x)
# Call a script function
x = script_fn(x)
# Call a traced submodule
x = self.traced_module(x)
# Call a scripted submodule
x = self.script_module(x)
return x
运行模型#
剩下要做的就是构建 Net 并通过 forward 函数计算输出。在这里,我们使用 \(x=5\) 作为测试输入值,并期望 \(Y(x)=190.\) 此外,请查看在 Net 构建期间打印的图。
# Instantiate this net and run it
n = Net()
print(n(torch.tensor([5]))) # 190.
traced_fn graph graph(%x : Float(requires_grad=0, device=cpu)):
%1 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
%2 : Float(requires_grad=0, device=cpu) = aten::mul(%x, %1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
%3 : Float(requires_grad=0, device=cpu) = aten::abs(%2) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
return (%3)
script_fn graph graph(%x.1 : Tensor):
%13 : bool = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:4
%4 : NoneType = prim::Constant()
%3 : int = prim::Constant[value=4]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:30
%1 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:20
%2 : int[] = prim::ListConstruct(%1)
%z.1 : Tensor = aten::ones(%2, %3, %4, %4, %4) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:8
%10 : int = aten::Int(%x.1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:19
%z : Tensor = prim::Loop(%10, %13, %z.1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:4
block0(%i.1 : int, %z.11 : Tensor):
%17 : int = aten::add(%i.1, %1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:128:17
%z.5 : Tensor = aten::mul(%z.11, %17) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:128:12
-> (%13, %z.5)
return (%z)
TracedModule graph graph(%self : __torch__.TracedModule,
%x.1 : Float(requires_grad=0, device=cpu)):
%4 : int = prim::Constant[value=6]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
%5 : bool = prim::Constant[value=0]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
%6 : bool = prim::Constant[value=0]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
%7 : NoneType = prim::Constant()
%x : Float(requires_grad=0, device=cpu) = aten::to(%x.1, %4, %5, %6, %7) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
%9 : Float(requires_grad=0, device=cpu) = aten::sqrt(%x) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
%10 : Double(requires_grad=0, device=cpu) = prim::Constant[value={5}]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
%11 : Float(requires_grad=0, device=cpu) = aten::div(%9, %10) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
%12 : Float(requires_grad=0, device=cpu) = aten::floor(%11) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
return (%12)
ScriptModule graph graph(%self : __torch__.ScriptModule,
%x.1 : Tensor):
%5 : float = prim::Constant[value=2.]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:29
%9 : float = prim::Constant[value=0.]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:38
%r.1 : Tensor = aten::neg(%x.1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:165:12
%6 : Tensor = aten::fmod(%x.1, %5) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:15
%8 : int = aten::Int(%6) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:11
%10 : bool = aten::eq(%8, %9) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:11
%r : Tensor = prim::If(%10) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:8
block0():
%r.3 : Tensor = aten::div(%x.1, %5) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:167:16
-> (%r.3)
block1():
-> (%r.1)
return (%r)
tensor([190.])
跟踪顶层模型#
本例的最后一部分是跟踪顶层模块 Net
。如前所述,由于跟踪/脚本化模块是 Net 的属性,我们能够跟踪 Net
,因为它继承了跟踪/脚本化子模块的参数。请注意,跟踪 Net 的语法与跟踪 TracedModule
的语法相同。此外,请查看创建的图。
n_traced = torch.jit.trace(n, torch.tensor([5]))
print(n_traced(torch.tensor([5])))
print('n_traced graph', n_traced.graph)
tensor([190.])
n_traced graph graph(%self : __torch__.Net,
%x.1 : Long(1, strides=[1], requires_grad=0, device=cpu)):
%script_module : __torch__.ScriptModule = prim::GetAttr[name="script_module"](%self)
%traced_module : __torch__.TracedModule = prim::GetAttr[name="traced_module"](%self)
%10 : Function = prim::Constant[name="fn"]()
%x : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::CallFunction(%10, %x.1)
%12 : Function = prim::Constant[name="script_fn"]()
%13 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::CallFunction(%12, %x)
%14 : Float(1, strides=[1], requires_grad=0, device=cpu) = prim::CallMethod[name="forward"](%traced_module, %13)
%15 : Float(1, strides=[1], requires_grad=0, device=cpu) = prim::CallMethod[name="forward"](%script_module, %14)
return (%15)
希望本文能为混合前端提供一个介绍,并为更有经验的用户提供一个语法参考指南。此外,在使用混合前端时,有几点需要记住。跟踪/脚本化方法必须以 Python 的受限子集编写,因为不支持生成器、定义和 Python 数据结构等功能。作为一种变通方法,脚本化模型 *旨在* 与跟踪和非跟踪代码一起工作,这意味着您可以从跟踪函数中调用非跟踪代码。但是,此类模型可能无法导出到 ONNX。
脚本总运行时间: (0 分钟 0.081 秒)