快捷方式

ProbabilisticTensorDictSequential

class tensordict.nn.ProbabilisticTensorDictSequential(*args, **kwargs)

包含至少一个 ProbabilisticTensorDictModuleTensorDictModules 序列。

此类扩展了 TensorDictSequential,通常配置为一系列模块,其中最后一个模块是 ProbabilisticTensorDictModule 的实例。但是,它也支持一种配置,即一个或多个中间模块是 ProbabilisticTensorDictModule 的实例,而最后一个模块可能是也可能不是概率性的。在所有情况下,它都公开了 get_dist() 方法,用于从序列中的 ProbabilisticTensorDictModule 实例中恢复分布对象。

多个概率模块可以共存于单个 ProbabilisticTensorDictSequential 中。如果 return_compositeFalse(默认值),则只有最后一个模块会生成分布,而其他模块将作为常规 TensorDictModule 实例执行。但是,如果 ProbabilisticTensorDictModule 不是序列中的最后一个模块且 return_composite=False,则在尝试查询模块时将引发 ValueError。如果 return_composite=True,则所有中间的 ProbabilisticTensorDictModule 实例都将贡献到一个 CompositeDistribution 实例。

如果样本相互依赖,则产生的对数概率将是条件概率:每当

\[Z = F(X, Y)\]

则 Z 的对数概率为

\[log(p(z | x, y))\]
参数:

*modules (sequenceOrderedDict of TensorDictModuleBaseProbabilisticTensorDictModule) – 一个有序的 TensorDictModule 实例序列,通常以 ProbabilisticTensorDictModule 结尾,将按顺序运行。模块可以是 TensorDictModuleBase 的实例,也可以是匹配此签名的任何其他可调用对象。请注意,如果使用非 TensorDictModuleBase 的可调用对象,其输入和输出键将不会被跟踪,因此也不会影响 TensorDictSequential 的 in_keysout_keys 属性。

关键字参数:
  • partial_tolerant (bool, optional) – 如果为 True,则输入 tensordict 可以缺少某些输入键。如果是这样,则只有根据存在的键可以执行的模块才会被执行。此外,如果输入 tensordict 是 tensordicts 的懒惰堆叠,并且 partial_tolerantTrue,并且堆叠不包含必需的键,则 TensorDictSequential 将扫描子 tensordicts,查找那些具有必需键的(如果有)。默认为 False

  • return_composite (bool, optional) – 如果 True 且找到多个 ProbabilisticTensorDictModuleProbabilisticTensorDictSequential 实例,则将使用 CompositeDistribution 实例。否则,将仅使用最后一个模块来构建分布。如果存在多个概率模块或最后一个模块不是概率性的,则默认为 True。如果 return_compositeFalse 且上述条件都不满足,则会引发错误。

  • inplace (bool, optional) – 如果为 True,则输入 tensordict 被就地修改。如果为 False,则会创建一个新的空 TensorDict 实例。如果为 “empty”,则使用 input.empty()(即输出保留类型、设备和批次大小)。默认为 None(依赖于子模块)。

抛出:

示例

>>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq
>>> import torch
>>> # Typical usage: a single distribution is computed last in the sequence
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq,         ...     TensorDictModule as Mod
>>> torch.manual_seed(0)
>>>
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions are ignored when return_composite=False
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     return_composite=False,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions produce a CompositeDistribution when return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     return_composite=True,
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when
>>> # return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]),
...     return_composite=True,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
build_dist_from_params(tensordict: TensorDictBase) Distribution

在不评估序列中其他模块的情况下,从输入参数构造分布。

此方法搜索序列中的最后一个 ProbabilisticTensorDictModule 并使用它来构建分布。

参数:

tensordict (TensorDictBase) – 包含分布参数的输入 tensordict。

返回:

构造的分布对象。

返回类型:

D.Distribution

抛出:

RuntimeError – 如果序列中未找到 ProbabilisticTensorDictModule

property default_interaction_type

通过迭代启发式方法返回模块的 default_interaction_type

此属性按反向顺序迭代所有模块,尝试从任何子模块检索 default_interaction_type 属性。遇到的第一个非 None 值将被返回。如果未找到此类值,则将返回默认的 interaction_type()

property dist_params_keys: List[NestedKey]

返回指向分布参数的所有键。

property dist_sample_keys: List[NestedKey]

返回指向分布样本的所有键。

forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs) TensorDictBase

当 tensordict 参数未设置时,kwargs 用于创建 TensorDict 的实例。

get_dist(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) Distribution

通过将输入 tensordict 通过序列传递来返回分布。

如果 return_compositeFalse(默认值),此方法将仅考虑序列中的最后一个概率模块。

否则,它将返回一个包含所有概率模块分布的 CompositeDistribution 实例。

参数:
  • tensordict (TensorDictBase) – 输入 tensordict。

  • tensordict_out (TensorDictBase, optional) – 输出 tensordict。如果为 None,将创建一个新的 tensordict。默认为 None

关键字参数:

**kwargs – 传递给底层模块的其他关键字参数。

返回:

结果分布对象。

返回类型:

D.Distribution

抛出:

RuntimeError – 如果序列中未找到概率模块。

注意

return_compositeTrue 时,分布将基于序列中的先前样本进行条件化。这意味着如果一个模块依赖于前一个概率模块的输出,那么它的分布将是条件性的。

get_dist_params(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, *, dist: Optional[Distribution] = None, **kwargs) tuple[torch.distributions.distribution.Distribution, tensordict.base.TensorDictBase]

返回分布参数和输出 tensordict。

此方法运行 ProbabilisticTensorDictSequential 模块的确定性部分以获取分布参数。如果可用,交互类型将设置为当前全局交互类型,否则默认为最后一个模块的交互类型。

参数:
  • tensordict (TensorDictBase) – 输入 tensordict。

  • tensordict_out (TensorDictBase, optional) – 输出 tensordict。如果为 None,将创建一个新的 tensordict。默认为 None

关键字参数:

**kwargs – 传递给模块确定性部分的附加关键字参数。

返回:

包含分布对象和输出 tensordict 的元组。

返回类型:

tuple[D.Distribution, TensorDictBase]

注意

交互类型在此方法执行期间会临时设置为指定值。

log_prob(tensordict, tensordict_out: Optional[TensorDictBase] = None, *, dist: Optional[Distribution] = None, **kwargs) tensordict.base.TensorDictBase | torch.Tensor

返回输入 tensordict 的对数概率。

如果 self.return_compositeTrue 且分布为 CompositeDistribution,则此方法将返回整个复合分布的对数概率。

否则,它将仅考虑序列中的最后一个概率模块。

参数:
  • tensordict (TensorDictBase) – 输入 tensordict。

  • tensordict_out (TensorDictBase, optional) – 输出 tensordict。如果为 None,将创建一个新的 tensordict。默认为 None

关键字参数:

dist (torch.distributions.Distribution, optional) – 分布对象。如果为 None,则将使用 get_dist 进行计算。默认为 None

返回:

输入 tensordict 的对数概率。

返回类型:

TensorDictBasetorch.Tensor

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源