torch.nn.functional.one_hot#
- torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor #
接收形状为
(*)
的索引值的 LongTensor,并返回形状为(*, num_classes)
的张量。该张量在最后一个维度上,除了索引值对应的位置为 1 外,其余位置均为 0。另请参阅 Wikipedia 上的独热编码。
- 参数
tensor (LongTensor) – 任意形状的类别值。
num_classes (int, optional) – 总类别数。如果设置为 -1,则类别数将根据输入张量中最大的类别值推断(最大类别值 + 1)。默认为 -1
- 返回
返回一个 LongTensor,它比输入张量多一个维度,在该维度上,输入张量指示的索引位置为 1,其余位置为 0。
示例
>>> F.one_hot(torch.arange(0, 5) % 3) tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]]) >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]) >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) tensor([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]])