tensordict.nn.distributions.CompositeDistribution¶
- class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, log_prob_key: Optional[NestedKey] = None, entropy_key: Optional[NestedKey] = None)¶
一个复合分布,使用 TensorDict 接口将多个分布分组在一起。
此类允许对一组分布执行诸如 `log_prob_composite`、`entropy_composite`、`cdf`、`icdf`、`rsample` 和 `sample` 等操作,并返回一个 TensorDict。输入的 TensorDict 可能会被就地修改。
- 参数:
params (TensorDictBase) – 一个嵌套的键-张量映射,其中根条目对应于样本名称,叶子是分布参数。条目名称必须与 `distribution_map` 中指定的名称匹配。
distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指定要使用的分布类型。分布的名称应与 `TensorDict` 中的样本名称匹配。
- 关键字参数:
name_map (Dict[NestedKey, NestedKey], optional) – 一个映射,指定每个样本应写入的位置。如果未提供,将使用 `distribution_map` 中的键名。
extra_kwargs (Dict[NestedKey, Dict], optional) – 一个字典,包含用于构造分布的附加关键字参数。
log_prob_key (NestedKey, optional) –
将存储聚合对数概率的键。默认为 `‘sample_log_prob’`。
注意
如果 `tensordict.nn.probabilistic.composite_lp_aggregate()` 返回 `False`,则对数概率将存储在 `(“path”, “to”, “leaf”, “<sample_name>_log_prob”)` 下,其中 `(“path”, “to”, “leaf”, “<sample_name>”)` 是与正在采样的叶子张量对应的 `NestedKey`。在这种情况下,将忽略 `log_prob_key` 参数。
entropy_key (NestedKey, optional) –
将存储熵的键。默认为 `‘entropy’`。
注意
如果 `tensordict.nn.probabilistic.composite_lp_aggregate()` 返回 `False`,则熵将存储在 `(“path”, “to”, “leaf”, “<sample_name>_entropy”)` 下,其中 `(“path”, “to”, “leaf”, “<sample_name>”)` 是与正在采样的叶子张量对应的 `NestedKey`。在这种情况下,将忽略 `entropy_key` 参数。
注意
包含参数(`params`)的输入 TensorDict 的批处理大小决定了分布的批处理形状。例如,调用 `log_prob` 产生的 `“sample_log_prob”` 条目的形状将是参数的形状加上任何额外的批处理维度。
另请参阅
ProbabilisticTensorDictModule
和ProbabilisticTensorDictSequential
以了解如何将此类用作模型的一部分。另请参阅
set_composite_lp_aggregate
以控制对数概率的聚合。示例
>>> params = TensorDict({ ... "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)}, ... ("nested", "disc"): {"logits": torch.randn(3, 10)} ... }, [3]) >>> dist = CompositeDistribution(params, ... distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical}) >>> sample = dist.sample((4,)) >>> with set_composite_lp_aggregate(False): ... sample = dist.log_prob(sample) ... print(sample) TensorDict( fields={ cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False), disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)