评价此页

torch.multinomial#

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor#

返回一个张量,其中每行包含从张量 input 对应行中位于多项式(更严格的定义应为多元,详见 torch.distributions.multinomial.Multinomial)概率分布中采样的 num_samples 个索引。

注意

input 的行不必总和为一(在这种情况下,我们将值用作权重),但必须是非负的、有限的,并且总和非零。

索引根据采样时间从左到右排序(首先采样的样本放在第一列)。

如果 input 是一个向量,out 是一个大小为 num_samples 的向量。

如果 input 是一个有 m 行的矩阵,out 是一个形状为 (m×num_samples)(m \times \text{num\_samples}) 的矩阵。

如果 replacementTrue,则进行有放回采样。

否则,进行无放回采样,这意味着当为某行采样一个索引后,该索引不能再次为该行采样。

注意

在无放回采样时,num_samples 必须小于 input 中非零元素的数量(如果 input 是矩阵,则为每行中非零元素的最小数量)。

参数
  • input (Tensor) – 包含概率的输入张量

  • num_samples (int) – 要采样的数量

  • replacement (bool, 可选) – 是否进行有放回采样

关键字参数
  • generator (torch.Generator, 可选) – 用于采样的伪随机数生成器

  • out (Tensor, optional) – 输出张量。

示例

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5) # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])