快捷方式

SafeProbabilisticModule

class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)[源码]

tensordict.nn.ProbabilisticTensorDictModule 的子类,它接受一个 TensorSpec 作为参数来控制输出域。

SafeProbabilisticModule 是一个非参数模块,封装了一个概率分布构造器。它从输入的 TensorDict 中读取分布参数,使用指定的 in_keys,并输出该分布的一个(宽泛意义上的)样本。

输出“样本”是根据一个规则生成的,该规则由输入参数 default_interaction_type 和全局函数 interaction_type() 指定。

SafeProbabilisticModule 可用于构造分布(通过 get_dist() 方法)和/或从中采样(通过对模块的常规 __call__() 调用)。

一个 SafeProbabilisticModule 实例具有两个主要特性:

  • 它从 TensorDict 对象读取并写入数据;

  • 它使用一个真实的映射 R^n -> R^m 来创建一个 R^d 中的分布,从中可以采样或计算值。

当调用 __call__()forward() 方法时,会创建一个分布,并计算一个值(根据 interaction_type 的值,可以使用 'dist.mean'、'dist.mode'、'dist.median' 属性,以及 'dist.rsample'、'dist.sample' 方法)。如果提供的 TensorDict 已包含所有期望的键值对,则会跳过采样步骤。

默认情况下,SafeProbabilisticModule 的分布类是 Delta 分布,这使得 SafeProbabilisticModule 成为一个简单的确定性映射函数包装器。

此类与 tensordict.nn.ProbabilisticTensorDictModule 的区别在于,它接受一个 spec 关键字参数,该参数可用于控制样本是否属于分布。 safe 关键字参数控制样本值是否应根据 spec 进行检查。

参数:
  • in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]) – 将从输入的 TensorDict 读取并用于构建分布的键。重要的是,如果它是 NestedKey 列表或 NestedKey,这些键的叶子(最后一个元素)必须与所关注的分布类使用的关键字匹配,例如,对于 Normal 分布,关键字为 "loc""scale",依此类推。如果 in_keys 是一个字典,键是分布的键,值是 tensordict 中将匹配相应分布键的键。

  • out_keys (NestedKey | List[NestedKey] | None) – 将写入采样值的键。重要的是,如果这些键在输入的 TensorDict 中找到,则会跳过采样步骤。

  • spec (TensorSpec) – 第一个输出张量的 spec。在调用 td_module.random() 生成目标空间中的随机值时使用。

关键字参数:
  • safe (bool, optional) – 如果为 True,则样本的值将根据输入 spec 进行检查。由于探索策略或数值下溢/上溢问题,可能会出现域外采样。与 spec 参数一样,此检查仅针对分布样本进行,而不是输入模块返回的其他张量。如果样本超出范围,则使用 TensorSpec.project 方法将其投影回所需空间。默认为 False

  • default_interaction_type (InteractionType, optional) –

    仅关键字参数。用于检索输出值的默认方法。应为 InteractionType 中的一个:MODE、MEDIAN、MEAN 或 RANDOM(在这种情况下,值将从分布中随机采样)。默认为 MODE。

    注意

    当绘制样本时,ProbabilisticTensorDictModule 实例将首先查找由全局函数 interaction_type() 指定的交互模式。如果此函数返回 None(其默认值),则将使用 ProbabilisticTDModule 实例的 default_interaction_type。请注意,DataCollectorBase 实例将默认使用 set_interaction_type 设置为 tensordict.nn.InteractionType.RANDOM

    注意

    在某些情况下,可能无法通过相应的属性直接获取 mode、median 或 mean 值。为解决此问题,ProbabilisticTensorDictModule 将首先尝试通过调用 get_mode()get_median()get_mean() 来获取值(如果方法存在)。

  • distribution_class (Type or Callable[[Any], Distribution], optional) –

    仅关键字参数。用于采样的 torch.distributions.Distribution 类。默认为 Delta

    注意

    如果 distribution_class 是 CompositeDistribution 类型,则 out_keys 可以直接从该类的 distribution_kwargs 关键字参数中通过 "distribution_map""name_map" 推断出来,在这种情况下 out_keys 是可选的。

  • distribution_kwargs (dict, optional) –

    仅关键字参数。要传递给分布的关键字参数对。

    注意

    如果您的 kwargs 包含您希望与模块一起传输到设备的张量,或者在调用 module.to(dtype) 时其 dtype 应被修改的张量,您可以将 kwargs 包装在 TensorDictParams 中以自动完成此操作。

  • return_log_prob (bool, optional) – 仅关键字参数。如果为 True,则分布样本的对数概率将以 log_prob_key 的键写入 tensordict。默认为 False

  • log_prob_keys (List[NestedKey], optional) –

    如果 return_log_prob=True,则写入 log_prob 的键。默认为 ‘<sample_key_name>_log_prob’,其中 <sample_key_name>out_keys 的每个元素。

    注意

    这仅在 composite_lp_aggregate() 设置为 False 时可用。

  • log_prob_key (NestedKey, optional) –

    如果 return_log_prob=True,则写入 log_prob 的键。默认为 ‘sample_log_prob’(当 composite_lp_aggregate() 设置为 True 时)或 ‘<sample_key_name>_log_prob’(否则)。

    注意

    当有多个样本时,这仅在 composite_lp_aggregate() 设置为 True 时可用。

  • cache_dist (bool, optional) – 仅关键字参数。实验性质:如果为 True,则分布的参数(即模块的输出)将与样本一起写入 tensordict。这些参数可用于以后重新计算原始分布(例如,在 PPO 中计算用于采样动作的分布与更新后的分布之间的散度)。默认为 False

  • n_empirical_estimate (int, optional) – 仅关键字参数。在均值不可用时计算经验均值的样本数量。默认为 1000。

警告

运行检查需要时间!使用 safe=True 将保证样本在 spec 边界内,这依赖于 project() 中编码的一些启发式方法,但这需要检查值是否在 spec 空间内,这将产生一些开销。

另请参阅

tensordict.nn.CompositeDistribution(复合分布)可用于创建多头策略。

示例

>>> from torchrl.modules import SafeProbabilisticModule
>>> from torchrl.data import Bounded
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import InteractionType
>>> mod = SafeProbabilisticModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=torch.distributions.Normal,
...     safe=True,
...     spec=Bounded(low=-1, high=1, shape=()),
...     default_interaction_type=InteractionType.RANDOM
... )
>>> _ = torch.manual_seed(0)
>>> data = TensorDict(
...     loc=torch.zeros(10, requires_grad=True),
...     scale=torch.full((10,), 10.0),
...     batch_size=(10,))
>>> data = mod(data)
>>> print(data["action"]) # All actions are within bound
tensor([ 1., -1., -1.,  1., -1., -1.,  1.,  1., -1., -1.],
       grad_fn=<ClampBackward0>)
>>> data["action"].mean().backward()
>>> print(data["loc"].grad) # clamp anihilates gradients
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
random(tensordict: TensorDictBase) TensorDictBase[源码]

独立于任何输入,在目标空间中随机采样一个元素。

如果存在多个输出键,则只有第一个键会写入输入的 tensordict 中。

参数:

tensordict (TensorDictBase) – 应将输出值写入的 tensordict。

返回:

包含输出键的新/更新值的原始 tensordict。

random_sample(tensordict: TensorDictBase) TensorDictBase[源码]

请参阅 SafeModule.random(...)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源