Ordinal¶
- class torchrl.modules.Ordinal(scores: Tensor)[source]¶
一个离散分布,用于学习从有限有序集合中采样。
它与 Categorical 分布形成对比,后者在其支持的原子上不施加任何邻近性或排序的概念。Ordinal 分布明确地编码了这些概念,这对于从连续集合学习离散采样非常有用。有关详细信息,请参阅 `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`_ 的 §5。
注意
当你想要学习一个在通过离散化连续集合得到的有限集合上的分布时,这个类非常有用。
- 参数:
scores (torch.Tensor) – 一个形状为 […, N] 的张量,其中 N 是支持该分布的集合的大小。通常,这是参数化该分布的神经网络的输出。
示例
>>> num_atoms, num_samples = 5, 20 >>> mean = (num_atoms - 1) / 2 # Target mean for samples, centered around the middle atom >>> torch.manual_seed(42) >>> logits = torch.ones((num_atoms), requires_grad=True) >>> optimizer = torch.optim.Adam([logits], lr=0.1) >>> >>> # Perform optimisation loop to minimise deviation from `mean` >>> for _ in range(20): >>> sampler = Ordinal(scores=logits) >>> samples = sampler.sample((num_samples,)) >>> # Define loss to encourage samples around the mean by penalising deviation from mean >>> loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples)) >>> loss.backward() >>> optimizer.step() >>> optimizer.zero_grad() >>> >>> sampler.probs tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...) >>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4) >>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms) torch.return_types.histogram( hist=tensor([ 24., 158., 478., 228., 112.]), bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))