快捷方式

ProbabilisticTensorDictModule

class tensordict.nn.ProbabilisticTensorDictModule(*args, **kwargs)

一个概率性 TD 模块。

ProbabilisticTensorDictModule 是一个嵌入概率分布构造函数的非参数模块。它使用指定的 in_keys 从输入 TensorDict 中读取分布参数,并输出(广义上)分布的样本。

输出的“样本”是根据一个规则生成的,该规则由输入参数 default_interaction_type 和全局函数 interaction_type() 指定。

ProbabilisticTensorDictModule 可用于构造分布(通过 get_dist() 方法)和/或从此分布中采样(通过对模块的常规 __call__() 调用)。

一个 ProbabilisticTensorDictModule 实例有两个主要特性:

  • 它从 TensorDict 对象中读写;

  • 它使用一个实数映射 R^n -> R^m 来创建一个 R^d 中的分布,可以从中采样或计算值。

当调用 __call__()forward() 方法时,会创建一个分布,并计算一个值(根据 interaction_type 的值,可以使用 'dist.mean'、'dist.mode'、'dist.median' 属性,以及 'dist.rsample'、'dist.sample' 方法)。如果提供的 TensorDict 已包含所有期望的键值对,则会跳过采样步骤。

默认情况下,ProbabilisticTensorDictModule 的分布类是 Delta 分布,这使得 ProbabilisticTensorDictModule 成为一个简单的确定性映射函数包装器。

参数:
  • in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]) – 将从输入 TensorDict 读取并用于构建分布的键。重要的是,如果它是 NestedKey 的列表或 NestedKey,则这些键的叶子(最后一个元素)必须匹配感兴趣的分布类使用的关键字,例如,对于 Normal 分布,关键字是 "loc""scale",依此类推。如果 in_keys 是一个字典,键是分布的键,值是 tensordict 中将与相应分布键匹配的键。

  • out_keys (NestedKey | List[NestedKey] | None) – 采样值将被写入的键。重要的是,如果这些键存在于输入 TensorDict 中,则会跳过采样步骤。

关键字参数:
  • default_interaction_type (InteractionType, optional) –

    仅关键字参数。用于检索输出值的默认方法。应为 InteractionType 中的一个:MODE、MEDIAN、MEAN 或 RANDOM(在这种情况下,值是从分布中随机采样的)。默认为 MODE。

    注意

    当绘制样本时,ProbabilisticTensorDictModule 实例将首先查找由全局函数 interaction_type() 指定的交互模式。如果此函数返回 None(其默认值),则将使用 ProbabilisticTDModule 实例的 default_interaction_type。请注意,DataCollectorBase 实例将默认使用 set_interaction_type 设置为 tensordict.nn.InteractionType.RANDOM

    注意

    在某些情况下,模式 (mode)、中位数 (median) 或均值 (mean) 值可能无法通过相应的属性轻松获得。为解决此问题,如果方法存在,ProbabilisticTensorDictModule 将首先尝试通过调用 get_mode()get_median()get_mean() 来获取值。

  • distribution_class (Type or Callable[[Any], Distribution], optional) –

    仅关键字参数。用于采样的 torch.distributions.Distribution 类。默认为 Delta

    注意

    如果分布类是 CompositeDistribution 类型,则 out_keys 可以直接从通过该类的 distribution_kwargs 关键字参数提供的 "distribution_map""name_map" 键推断出来,在这种情况下 out_keys 是可选的。

  • distribution_kwargs (dict, optional) –

    仅关键字参数。将传递给分布的关键字参数对。

    注意

    如果您的 kwargs 包含您希望与模块一起传输到设备上的张量,或者在调用 module.to(dtype) 时应修改其 dtype 的张量,您可以将 kwargs 包装在 TensorDictParams 中以自动执行此操作。

  • return_log_prob (bool, optional) – 仅关键字参数。如果为 True,则对分布样本的对数概率将以 log_prob_key 键写入 tensordict。默认为 False

  • log_prob_keys (List[NestedKey], optional) –

    如果 return_log_prob=True,则写入 log_prob 的键。默认为 ‘<sample_key_name>_log_prob’,其中 <sample_key_name> 是每个 out_keys

    注意

    这仅在 composite_lp_aggregate() 设置为 False 时可用。

  • log_prob_key (NestedKey, optional) –

    如果 return_log_prob=True,则写入 log_prob 的键。当 composite_lp_aggregate() 设置为 True 时,默认为 ‘sample_log_prob’,否则为 ‘<sample_key_name>_log_prob’

    注意

    当有多个样本时,这仅在 composite_lp_aggregate() 设置为 True 时可用。

  • cache_dist (bool, optional) – 仅关键字参数。实验性:如果为 True,则分布的参数(即模块的输出)将与样本一起写入 tensordict。这些参数可用于稍后重新计算原始分布(例如,在 PPO 中计算用于采样动作的分布与更新后的分布之间的散度)。默认为 False

  • n_empirical_estimate (int, optional) – 仅关键字参数。在均值不可用时计算经验均值的样本数量。默认为 1000。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch.distributions import Normal, Independent
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
...     net, in_keys=["input", "hidden"], out_keys=["params"]
... )
>>> normal_params = TensorDictModule(
...     NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
... )
>>> def IndepNormal(**kwargs):
...     return Independent(Normal(**kwargs), 1)
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=IndepNormal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(
...     module, normal_params, prob_module
... )
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> with params.to_module(td_module):
...     dist = td_module.get_dist(td)
>>> print(dist)
Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1)
>>> # we can also apply the module to the TensorDict with vmap
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
build_dist_from_params(tensordict: TensorDictBase) Distribution

使用输入 tensordict 中提供的参数创建一个 torch.distribution.Distribution 实例。

参数:

tensordict (TensorDictBase) – 包含分布参数的输入 tensordict。

返回:

一个从输入 tensordict 创建的 torch.distribution.Distribution 实例。

抛出:

TypeError – 如果输入 tensordict 与分布关键字不匹配。

property dist_params_keys: List[NestedKey]

返回指向分布参数的所有键。

property dist_sample_keys: List[NestedKey]

返回指向分布样本的所有键。

forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, _requires_sample: bool = True) TensorDictBase

定义每次调用时执行的计算。

所有子类都应重写此方法。

注意

尽管前向传播的逻辑需要在该函数中定义,但应在之后调用 Module 实例而不是此函数,因为前者负责运行已注册的钩子,而后者会默默忽略它们。

get_dist(tensordict: TensorDictBase) Distribution

使用输入 tensordict 中提供的参数创建一个 torch.distribution.Distribution 实例。

参数:

tensordict (TensorDictBase) – 包含分布参数的输入 tensordict。

返回:

一个从输入 tensordict 创建的 torch.distribution.Distribution 实例。

抛出:

TypeError – 如果输入 tensordict 与分布关键字不匹配。

log_prob(tensordict, *, dist: Optional[Distribution] = None)

计算分布样本的对数概率。

参数:
  • tensordict (TensorDictBase) – 包含分布参数的输入 tensordict。

  • dist (torch.distributions.Distribution, optional) – 分布实例。默认为 None。如果为 None,则使用 get_dist 方法计算分布。

返回:

表示分布样本对数概率的张量。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源