快捷方式

MaskedCategorical

class torchrl.modules.MaskedCategorical(logits: torch.Tensor | None = None, probs: torch.Tensor | None = None, *, mask: torch.Tensor | None = None, indices: torch.Tensor | None = None, neg_inf: float = - inf, padding_value: int | None = None, use_cross_entropy: bool = True, padding_side: str = 'left')[源代码]

MaskedCategorical 分布。

参考: https://tensorflowcn.cn/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

参数:
  • logits (torch.Tensor) – 事件的对数概率(未归一化)

  • probs (torch.Tensor) – 事件概率。如果提供了概率,则与掩码项对应的概率将被归零,并在最后一个维度上重新归一化概率。

关键字参数:
  • mask (torch.Tensor) – 与 logits/probs 形状相同的布尔掩码,其中 False 条目是被掩码的。或者,如果 sparse_mask 为 True,则它表示分布中有效索引的列表。与 indices 互斥。

  • indices (torch.Tensor) – 一个密集索引张量,表示必须考虑哪些动作。与 mask 互斥。

  • neg_inf (float, optional) – 分配给无效(掩码外)索引的对数概率值。默认为 -inf。

  • padding_value – 掩码张量中的填充值。当 sparse_mask == True 时,将忽略 padding_value。

  • use_cross_entropy (bool, optional) – 为了更快地计算对数概率,可以使用 cross_entropy 损失函数。默认为 True

  • padding_side (str, optional) – 填充的侧面。默认为 "left"

示例

>>> torch.manual_seed(0)
>>> logits = torch.randn(4) / 100  # almost equal probabilities
>>> mask = torch.tensor([True, False, True, True])
>>> dist = MaskedCategorical(logits=logits, mask=mask)
>>> sample = dist.sample((10,))
>>> print(sample)  # no `1` in the sample
tensor([2, 3, 0, 2, 2, 0, 2, 0, 2, 2])
>>> print(dist.log_prob(sample))
tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
        -1.1203, -1.1203])
>>> print(dist.log_prob(torch.ones_like(sample)))
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
>>> # with probabilities
>>> prob = torch.ones(10)
>>> prob = prob / prob.sum()
>>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
>>> dist = MaskedCategorical(probs=prob, mask=mask)
>>> print(dist.log_prob(torch.arange(10)))
tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
        -2.1972, -2.1972])
entropy()[源代码]

计算分布的熵。

对于掩码分布,我们只考虑有效(未掩码)结果的熵。无效结果的概率为零,不计入熵。

log_prob(value: Tensor) Tensor[源代码]

返回在 value 处评估的概率密度/质量函数的对数。

参数:

value (Tensor) –

property padding_value

分布掩码的填充值。

如果未设置填充值,则会从 logits 中推断。

sample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor[源代码]

生成 sample_shape 形状的样本,如果分布参数是批处理的,则生成 sample_shape 形状的样本批次。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源