torch.multinomial#
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor #
返回一个张量,其中每行包含从张量
input
对应行中位于多项式(更严格的定义应为多元,详见torch.distributions.multinomial.Multinomial
)概率分布中采样的num_samples
个索引。注意
input
的行不必总和为一(在这种情况下,我们将值用作权重),但必须是非负的、有限的,并且总和非零。索引根据采样时间从左到右排序(首先采样的样本放在第一列)。
如果
input
是一个向量,out
是一个大小为num_samples
的向量。如果
input
是一个有 m 行的矩阵,out
是一个形状为 的矩阵。如果
replacement
为True
,则进行有放回采样。否则,进行无放回采样,这意味着当为某行采样一个索引后,该索引不能再次为该行采样。
注意
在无放回采样时,
num_samples
必须小于input
中非零元素的数量(如果input
是矩阵,则为每行中非零元素的最小数量)。- 参数
- 关键字参数
generator (
torch.Generator
, 可选) – 用于采样的伪随机数生成器out (Tensor, optional) – 输出张量。
示例
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights >>> torch.multinomial(weights, 2) tensor([1, 2]) >>> torch.multinomial(weights, 5) # ERROR! RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement >>> torch.multinomial(weights, 4, replacement=True) tensor([ 2, 1, 1, 1])