注意
跳转到结尾 下载完整的示例代码。
导出 tensordict 模块¶
先决条件¶
最好先阅读 TensorDictModule 教程,以充分受益于本教程。
一旦使用 tensordict.nn
编写了模块,通常需要隔离计算图并导出该图。其目标可能是为了在硬件(例如机器人、无人机、边缘设备)上执行模型,或者完全消除对 tensordict 的依赖。
PyTorch 提供了多种导出模块的方法,包括 onnx
和 torch.export
,这两者都与 tensordict
兼容。
在本简短教程中,我们将了解如何使用 torch.export
来隔离模型的计算图。torch.onnx
支持遵循相同的逻辑。
关键学习点¶
在没有
TensorDict
输入的情况下执行tensordict.nn
模块;选择模型的输出;
处理随机模型;
使用 torch.export 导出此类模型;
将模型保存到文件;
隔离 pytorch 模型;
import time
import torch
from tensordict.nn import (
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule as Prob,
set_interaction_type,
TensorDictModule as Mod,
TensorDictSequential as Seq,
)
from torch import distributions as dists, nn
设计模型¶
在许多应用中,使用随机模型很有用,即输出一个不是确定定义的,而是根据参数化分布进行采样的变量的模型。例如,生成式 AI 模型在提供相同输入时通常会生成不同的输出,因为它们根据由输入定义的参数的分布进行采样。
tensordict
库通过 ProbabilisticTensorDictModule
类来处理这个问题。这个基元是使用分布类(在我们的例子中是 Normal
)和将在执行时用于构建该分布的输入键的指示器构建的。
因此,我们正在构建的网络将是三个主要组件的组合:
将输入映射到潜在参数的网络;
一个
tensordict.nn.NormalParamExtractor
模块,将输入分割为要传递给Normal
分布的位置 “loc” 和 “scale” 参数;一个分布构造模块。
model = Seq(
# 1. A small network for embedding
Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
# 2. Extracting params
Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
# 3. Probabilistic module
Prob(
in_keys=["loc", "scale"],
out_keys=["sample"],
distribution_class=dists.Normal,
),
)
让我们运行这个模型,看看输出是什么样的。
x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.7624, 0.2919, 0.3075, 0.1171]], grad_fn=<ReluBackward0>), tensor([[ 0.2043, -0.5592, 0.3365, 0.3537]], grad_fn=<AddmmBackward0>), tensor([[ 0.2043, -0.5592]], grad_fn=<SplitBackward0>), tensor([[1.2243, 1.2364]], grad_fn=<ClampMinBackward0>), tensor([[ 0.2043, -0.5592]], grad_fn=<SplitBackward0>))
正如预期的那样,使用张量输入运行模型会返回与模块的输出键一样多的张量!对于大型模型来说,这可能相当烦人且浪费。稍后,我们将看到如何限制模型的输出数量来解决这个问题。
将 torch.export
与 TensorDictModule
结合使用¶
现在我们已经成功构建了模型,我们希望将其计算图提取到一个独立于 tensordict
的单个对象中。torch.export
是一个专门用于隔离模块图并以标准化方式表示它的 PyTorch 模块。它的主要入口点是 export()
,它返回一个 ExportedProgram
对象。反过来,该对象有几个我们将在下面探讨的感兴趣的属性:一个 graph_module
,它表示 export
捕获的 FX 图;一个 graph_signature
,包含图的输入、输出等;最后是一个 module()
,它返回一个可以替代原始模块的可调用对象。
虽然我们的模块接受 args 和 kwargs,但我们将重点关注其 kwargs 的用法,因为这样更清晰。
from torch.export import export
model_export = export(model, args=(), kwargs={"x": x}, strict=True)
让我们看看这个模块。
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
return pytree.tree_unflatten((relu, linear_1, getitem_2, getitem_3, getitem_2), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
该模块可以像我们的原始模块一样运行(开销更低)。
Time for TDModule: 519.28 micro-seconds
Time for exported module: 671.15 micro-seconds
以及 FX 图。
print("fx graph:", model_export.graph_module.print_readable())
class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/tensordict/nn/common.py:1133 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem: "f32[1, 2]" = chunk[0]
getitem_1: "f32[1, 2]" = chunk[1]; chunk = None
# File: /pytorch/tensordict/tensordict/nn/utils.py:70 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:58 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
fx graph: class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/tensordict/nn/common.py:1133 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem: "f32[1, 2]" = chunk[0]
getitem_1: "f32[1, 2]" = chunk[1]; chunk = None
# File: /pytorch/tensordict/tensordict/nn/utils.py:70 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:58 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
处理嵌套键¶
嵌套键是 tensordict 库的核心功能,因此能够导出读取和写入嵌套条目的模块是一项重要的支持功能。由于关键字参数必须是常规字符串,因此 dispatch
无法直接使用它们。相反,dispatch
将解包由常规下划线(“_”)连接的嵌套键,如下面的示例所示。
model_nested = Seq(
Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))
model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule()
def forward(self, some_key):
some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
add = torch.ops.aten.add.Tensor(some_key, 1); some_key = None
sub = torch.ops.aten.sub.Tensor(add, 1); add = None
return pytree.tree_unflatten((sub,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
请注意,由 module() 返回的可调用对象是纯 Python 可调用对象,可以进一步使用 compile()
进行编译。
保存导出的模块¶
torch.export
有自己的序列化协议,save()
和 load()
。通常,应使用 “.pt2” 扩展名。
>>> torch.export.save(model_export, "model.pt2")
选择输出¶
回想一下,tensordict.nn
的目的是保留所有中间值作为输出,除非用户明确要求只获取特定值。在训练期间,这可能非常有用:可以轻松记录图的中间值,或将它们用于其他目的(例如,根据保存的参数重构分布,而不是保存 Distribution
对象本身)。也可以认为,在训练期间,注册中间值对内存的影响可以忽略不计,因为它们是 torch.autograd
用于计算参数梯度的计算图的一部分。
然而,在推理期间,我们最感兴趣的可能是模型的最终样本。由于我们希望提取模型以供独立于 tensordict
库使用的用途,因此隔离我们想要的唯一输出来得有意义。为此,我们有几种选择:
使用
selected_out_keys
关键字参数构建TensorDictSequential()
,这将导致在调用模块期间选择所需的条目;使用
select_out_keys()
方法,该方法将原地修改out_keys
属性(可以通过reset_out_keys()
恢复)。将现有实例包装在
TensorDictSequential()
中,该实例将过滤掉不需要的键。>>> module_filtered = Seq(module, selected_out_keys=["sample"])
让我们在选择输出键后测试模型。当提供 x 输入时,我们期望我们的模型输出一个对应于分布样本的单个张量。
tensor([[ 0.2043, -0.5592]], grad_fn=<SplitBackward0>)
我们看到输出现在是一个单个张量,对应于分布的样本。我们可以从此创建一个新的导出图。其计算图应该被简化。
model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = getitem_3 = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
控制采样策略¶
我们还没有讨论 ProbabilisticTensorDictModule
如何从分布中采样。通过采样,我们指的是根据特定策略获取定义在分布空间内的值。例如,在训练期间可能希望获得随机样本,而在推理时间获得确定性样本(例如,均值或众数)。为了解决这个问题,tensordict
利用 set_interaction_type
装饰器和上下文管理器,它接受 InteractionType
枚举输入。
>>> with set_interaction_type(InteractionType.MEAN):
... output = module(input) # takes the input of the distribution, if ProbabilisticTensorDictModule is invoked
默认的 InteractionType
是 InteractionType.DETERMINISTIC
,如果它没有直接实现,它要么是具有实数域的分布的均值,要么是具有离散域的分布的众数。此默认值可以通过 ProbabilisticTensorDictModule
的 default_interaction_type
关键字参数来更改。
总而言之:要控制网络的采样策略,我们可以在构造函数中定义默认采样策略,或者通过 set_interaction_type
上下文管理器在运行时覆盖它。
从下面的示例可以看出,torch.export
正确响应了装饰器的使用:如果我们要求随机样本,输出将与要求均值时的输出不同。
with set_interaction_type(InteractionType.RANDOM):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
with set_interaction_type(InteractionType.MEAN):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
empty = torch.ops.aten.empty.memory_format([1, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
normal_ = torch.ops.aten.normal_.default(empty); empty = None
mul = torch.ops.aten.mul.Tensor(normal_, getitem_3); normal_ = getitem_3 = None
add_2 = torch.ops.aten.add.Tensor(getitem_2, mul); getitem_2 = mul = None
return pytree.tree_unflatten((add_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = getitem_3 = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
以上是使用 torch.export
所需了解的所有信息。有关更多信息,请参阅 官方文档。
后续步骤和扩展阅读¶
查看
torch.export
教程,可在 此处 找到;ONNX 支持:请参阅 ONNX 教程 了解此功能的更多信息。导出到 ONNX 与此处解释的 torch.export 非常相似。
要在没有 Python 环境的服务器上部署 PyTorch 代码,请查看 AOTInductor 文档。
脚本总运行时间: (0 分 5.126 秒)