快捷方式

OneHotCategorical

class torchrl.modules.OneHotCategorical(logits: torch.Tensor | None = None, probs: torch.Tensor | None = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)[源代码]

独热(One-hot)分类分布。

此类的行为与 torch.distributions.Categorical 完全相同,但它读取和生成离散张量的 one-hot 编码。

参数:
  • logits (torch.Tensor) – 事件的对数概率(未归一化)

  • probs (torch.Tensor) – 事件的概率

  • grad_method (ReparamGradientStrategy, optional) –

    用于收集重参数化样本的策略。ReparamGradientStrategy.PassThrough 将使用 softmax 值的对数概率作为样本梯度的代理来计算样本梯度。

    使用 softmax 值的对数概率作为样本梯度的代理来计算样本梯度。

    ReparamGradientStrategy.RelaxedOneHot 将使用 torch.distributions.RelaxedOneHot 从分布中采样。

示例

>>> torch.manual_seed(0)
>>> logits = torch.randn(4)
>>> dist = OneHotCategorical(logits=logits)
>>> print(dist.rsample((3,)))
tensor([[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])
entropy()[源代码]

返回分布的熵,按 batch_shape 批处理。

返回:

形状为 batch_shape 的张量。

log_prob(value: Tensor) Tensor[源代码]

返回在 value 处评估的概率密度/质量函数的对数。

参数:

value (Tensor) –

property mode: Tensor

返回分布的众数。

rsample(sample_shape: torch.Size | Sequence = None) torch.Tensor[源代码]

生成 sample_shape 形状的重参数化样本,如果分布参数是批处理的,则生成 sample_shape 形状的重参数化样本批次。

sample(sample_shape: torch.Size | Sequence | None = None) torch.Tensor[源代码]

生成 sample_shape 形状的样本,如果分布参数是批处理的,则生成 sample_shape 形状的样本批次。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源