MaskedOneHotCategorical¶
- class torchrl.modules.MaskedOneHotCategorical(logits: torch.Tensor | None = None, probs: torch.Tensor | None = None, mask: torch.Tensor = None, indices: torch.Tensor = None, neg_inf: float = - inf, padding_value: int | None = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)[source]¶
MaskedCategorical 分布。
参考: https://tensorflowcn.cn/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical
- 参数:
logits (torch.Tensor) – 事件的对数概率(未归一化)
probs (torch.Tensor) – 事件概率。如果提供了该参数,则与屏蔽项对应的概率将被置零,并且概率将在其最后一个维度上重新归一化。
- 关键字参数:
mask (torch.Tensor) – 与
logits
/probs
具有相同形状的布尔掩码,其中False
条目是要屏蔽的。或者,如果sparse_mask
为 True,它将表示分布中的有效索引列表。与indices
互斥。indices (torch.Tensor) – 一个密集索引张量,表示必须考虑哪些动作。与
mask
互斥。neg_inf (
float
, 可选) – 分配给无效(超出掩码)索引的对数概率值。默认为 -inf。padding_value – 在 `sparse_mask == True` 时掩码张量中的填充值,`padding_value` 将被忽略。
grad_method (ReparamGradientStrategy, 可选) –
用于收集重参数化样本的策略。
ReparamGradientStrategy.PassThrough
将使用 softmax 值对数概率作为样本梯度代理来计算样本梯度。通过使用 softmax 值对数概率作为样本梯度代理来计算样本梯度。
ReparamGradientStrategy.RelaxedOneHot
将使用torch.distributions.RelaxedOneHot
从分布中采样。
示例
>>> torch.manual_seed(0) >>> logits = torch.randn(4) / 100 # almost equal probabilities >>> mask = torch.tensor([True, False, True, True]) >>> dist = MaskedOneHotCategorical(logits=logits, mask=mask) >>> sample = dist.sample((10,)) >>> print(sample) # no `1` in the sample tensor([[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]]) >>> print(dist.log_prob(sample)) tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831, -1.1203, -1.1203]) >>> sample_non_valid = torch.zeros_like(sample) >>> sample_non_valid[..., 1] = 1 >>> print(dist.log_prob(sample_non_valid)) tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) >>> # with probabilities >>> prob = torch.ones(10) >>> prob = prob / prob.sum() >>> mask = torch.tensor([False] + 9 * [True]) # first outcome is masked >>> dist = MaskedOneHotCategorical(probs=prob, mask=mask) >>> s = torch.arange(10) >>> s = torch.nn.functional.one_hot(s, 10) >>> print(dist.log_prob(s)) tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972])
- rsample(sample_shape: torch.Size | Sequence = None) torch.Tensor [source]¶
生成 sample_shape 形状的重参数化样本,如果分布参数是批处理的,则生成 sample_shape 形状的重参数化样本批次。
- sample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor [source]¶
生成 sample_shape 形状的样本,如果分布参数是批处理的,则生成 sample_shape 形状的样本批次。