快捷方式

tensordict.nn 包

tensordict.nn 包使得在 ML 流水线中灵活使用 TensorDict 成为可能。

由于 TensorDict 将代码的某些部分转换为基于键的结构,因此现在可以使用这些键作为钩子来构建复杂的图结构。基本构建块是 TensorDictModule,它用一组输入和输出键包装了一个 torch.nn.Module 实例。

>>> from torch.nn import Transformer
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> import torch
>>> module = TensorDictModule(Transformer(), in_keys=["feature", "target"], out_keys=["prediction"])
>>> data = TensorDict({"feature": torch.randn(10, 11, 512), "target": torch.randn(10, 11, 512)}, [10, 11])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        feature: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
        prediction: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
        target: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32)},
    batch_size=torch.Size([10, 11]),
    device=None,
    is_shared=False)

不一定需要使用 TensorDictModule,一个具有有序输入和输出键(名为 module.in_keysmodule.out_keys)的自定义 torch.nn.Module 就足够了。

许多 PyTorch 用户的一个痛点是 nn.Sequential 无法处理具有多个输入的模块。使用基于键的图可以轻松解决此问题,因为序列中的每个节点都知道需要读取哪些数据以及将数据写入何处。

为此,我们提供了 TensorDictSequential 类,它将数据传递给 TensorDictModules 序列。序列中的每个模块都从原始 TensorDict 获取输入,并将其输出写入 TensorDict,这意味着序列中的模块可以忽略其前驱的输出,或者根据需要从 TensorDict 获取额外的输入。示例如下:

>>> from tensordict.nn import TensorDictSequential
>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
>>> class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
...
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>>
>>> td = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> td = module(td)
>>> print(td)
TensorDict(
    fields={
        input: TensorDict(
            fields={
                mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
                x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        intermediate: TensorDict(
            fields={
                x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                probabilities: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)

我们还可以通过 select_subsequence() 方法轻松选择子图。

>>> sub_module = module.select_subsequence(out_keys=[("intermediate", "x")])
>>> td = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> sub_module(td)
>>> print(td)  # the "output" has not been computed
TensorDict(
    fields={
        input: TensorDict(
            fields={
                mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
                x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        intermediate: TensorDict(
            fields={
                x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)

最后,tensordict.nn 提供了一个 ProbabilisticTensorDictModule,它允许根据网络输出来构建分布,并从中获取摘要统计信息或样本(以及分布参数)。

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from tensordict.nn.prototype import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
... )
>>> from torch.distributions import Normal
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.Sequential(torch.nn.GRUCell(4, 8), NormalParamExtractor())
>>> module = TensorDictModule(
...     net, in_keys=["input", "hidden"], out_keys=["loc", "scale"]
... )
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["sample"],
...     distribution_class=Normal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(module, prob_module)
>>> td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
        input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

TensorDictModuleBase(*args, **kwargs)

TensorDict 模块的基类。

TensorDictModule(*args, **kwargs)

TensorDictModule 是 nn.Module 的 Python 包装器,用于读写 TensorDict。

ProbabilisticTensorDictModule(*args, **kwargs)

概率 TD 模块。

ProbabilisticTensorDictSequential(*args, ...)

一个包含至少一个 ProbabilisticTensorDictModuleTensorDictModules 序列。

TensorDictSequential(*args, **kwargs)

TensorDictModules 的序列。

TensorDictModuleWrapper(*args, **kwargs)

TensorDictModule 对象的包装类。

CudaGraphModule(module[, warmup, in_keys, ...])

PyTorch 可调用对象的 cudagraph 包装器。

WrapModule(*args, **kwargs)

处理 TensorDict 实例的任何可调用对象的包装器。

set_composite_lp_aggregate([mode])

控制 CompositeDistribution 的对数概率和熵是否将被聚合到单个张量中。

composite_lp_aggregate([nowarn])

返回 CompositeDistribution 的对数概率和熵是否将被聚合到单个张量中。

as_tensordict_module(*, in_keys, out_keys)

将函数转换为 TensorDictModule 的装饰器。

集成

函数式方法使得实现集成变得简单。我们可以使用 tensordict.nn.EnsembleModule 来复制和重新初始化模型副本。

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 10]),
    device=None,
    is_shared=False)

EnsembleModule(*args, **kwargs)

包装模块并重复使用它来形成集成的模块。

编译 TensorDictModules

自 v0.5 起,TensorDict 组件与 compile() 兼容。例如,可以通过 torch.compile 编译 TensorDictSequential 模块,并达到与包装在 TensorDictModule 中的普通 PyTorch 模块相似的运行时性能。

分布

AddStateIndependentNormalScale([...])

一个添加可训练的与状态无关的尺度参数的 nn.Module。

CompositeDistribution(params, ...[, ...])

一个复合分布,使用 TensorDict 接口将多个分布组合在一起。

Delta(param[, atol, rtol, batch_shape, ...])

Delta 分布。

NormalParamExtractor([scale_mapping, scale_lb])

一个非参数 nn.Module,它将输入分割为 loc 和 scale 参数。

OneHotCategorical([logits, probs])

独热(One-hot)分类分布。

TruncatedNormal(loc, scale, a, b[, ...])

截断正态分布。

工具

make_tensordict([input_dict, batch_size, ...])

从关键字参数或输入字典返回一个创建的 TensorDict。

dispatch([separator, source, dest, ...])

允许使用 kwargs 调用期望 TensorDict 的函数。

inv_softplus(bias)

反向 softplus 函数。

biased_softplus(bias[, min_val])

带偏置的 softplus 模块。

set_skip_existing([mode, in_key_attr, ...])

用于在 TensorDict 图中跳过现有节点的上下文管理器。

skip_existing()

返回一个模块是否应该重新计算 tensordict 中的现有条目。

rand_one_hot(values[, do_softmax])

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源