torch.topk#
- torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)#
返回给定
input张量在给定维度上的k个最大元素。如果未指定 `
dim`,则选择 `input` 的最后一个维度。如果
largest为False,则返回 k 个最小的元素。返回一个命名元组 (values, indices),其中包含 input 张量在给定维度 dim 的每一行中最大的 k 个元素的 values 和 indices。
布尔选项
sorted如果为True,将确保返回的 k 个元素本身是排序的。注意
使用 torch.topk 时,相同元素的索引不保证稳定,并且在不同调用中可能有所不同。
- 参数:
- 关键字参数:
out (tuple, optional) – 可以选择提供的 (Tensor, LongTensor) 输出元组,用作输出缓冲区
示例
>>> x = torch.arange(1., 6.) >>> x tensor([ 1., 2., 3., 4., 5.]) >>> torch.topk(x, 3) torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))