torch.sort#
- torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)#
对 `
input
` 张量沿着给定维度按值升序排序其元素。如果未指定 `
dim
`,则选择 `input` 的最后一个维度。如果 `
descending
` 为 `True
`,则元素按值降序排序。如果 `
stable
` 为 `True
`,则排序例程将变为稳定排序,保持等价元素的顺序。返回一个包含 (values, indices) 的命名元组,其中 values 是排序后的值,indices 是原始 input 张量中元素的索引。
- 参数
- 关键字参数
out (tuple, optional) – 可选的输出元组 (Tensor, LongTensor),用于作为输出缓冲区
示例
>>> x = torch.randn(3, 4) >>> sorted, indices = torch.sort(x) >>> sorted tensor([[-0.2162, 0.0608, 0.6719, 2.3332], [-0.5793, 0.0061, 0.6058, 0.9497], [-0.5071, 0.3343, 0.9553, 1.0960]]) >>> indices tensor([[ 1, 0, 2, 3], [ 3, 1, 0, 2], [ 0, 3, 1, 2]]) >>> sorted, indices = torch.sort(x, 0) >>> sorted tensor([[-0.5071, -0.2162, 0.6719, -0.5793], [ 0.0608, 0.0061, 0.9497, 0.3343], [ 0.6058, 0.9553, 1.0960, 2.3332]]) >>> indices tensor([[ 2, 0, 0, 1], [ 0, 1, 1, 2], [ 1, 2, 2, 0]]) >>> x = torch.tensor([0, 1] * 9) >>> x.sort() torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) >>> x.sort(stable=True) torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17]))