OneHotCategorical¶
- class torchrl.modules.OneHotCategorical(logits: torch.Tensor | None = None, probs: torch.Tensor | None = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)[source]¶
独热(One-hot)分类分布。
此类行为与 torch.distributions.Categorical 完全相同,但它读取和生成离散张量的一次热编码。
- 参数:
logits (torch.Tensor) – 事件的对数概率(未归一化)
probs (torch.Tensor) – 事件的概率
grad_method (ReparamGradientStrategy, optional) –
用于收集重参数化样本的策略。
ReparamGradientStrategy.PassThrough
将使用 softmax 值对数概率作为样本梯度的代理来计算样本梯度。使用 softmax 值作为样本梯度的代理来计算样本梯度。
ReparamGradientStrategy.RelaxedOneHot
将使用torch.distributions.RelaxedOneHot
从分布中采样。
示例
>>> torch.manual_seed(0) >>> logits = torch.randn(4) >>> dist = OneHotCategorical(logits=logits) >>> print(dist.rsample((3,))) tensor([[1., 0., 0., 0.], [0., 0., 0., 1.], [1., 0., 0., 0.]])
- log_prob(value: torch.Tensor) torch.Tensor [source]¶
返回在 value 处评估的概率密度/质量函数的对数。
- 参数:
value (Tensor) –
- rsample(sample_shape: torch.Size | Sequence = None) torch.Tensor [source]¶
生成 sample_shape 形状的重参数化样本,如果分布参数是批处理的,则生成 sample_shape 形状的重参数化样本批次。
- sample(sample_shape: torch.Size | Sequence | None = None) torch.Tensor [source]¶
生成 sample_shape 形状的样本,如果分布参数是批处理的,则生成 sample_shape 形状的样本批次。