评价此页

torch.nn.functional.one_hot#

torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor#

接收形状为 (*) 的索引值的 LongTensor,并返回形状为 (*, num_classes) 的张量。该张量在最后一个维度上,除了索引值对应的位置为 1 外,其余位置均为 0。

另请参阅 Wikipedia 上的独热编码

参数
  • tensor (LongTensor) – 任意形状的类别值。

  • num_classes (int, optional) – 总类别数。如果设置为 -1,则类别数将根据输入张量中最大的类别值推断(最大类别值 + 1)。默认为 -1

返回

返回一个 LongTensor,它比输入张量多一个维度,在该维度上,输入张量指示的索引位置为 1,其余位置为 0。

示例

>>> F.one_hot(torch.arange(0, 5) % 3)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])
>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0]])
>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
tensor([[[1, 0, 0],
         [0, 1, 0]],
        [[0, 0, 1],
         [1, 0, 0]],
        [[0, 1, 0],
         [0, 0, 1]]])