torch.nn.functional.one_hot#
- torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor#
接受形状为
(*)的索引值的 LongTensor,并返回一个形状为(*, num_classes)的张量。该张量除了在最后一个维度索引与输入张量对应值匹配的位置为 1 外,其余位置均为 0。另请参阅 维基百科上的独热编码 (One-hot)。
- 参数:
tensor (LongTensor) – 任意形状的类别值。
num_classes (int, optional) – 总类别数。如果设置为 -1,则类别数将推断为输入张量中最大类别值加一。默认为 -1。
- 返回:
LongTensor,其比输入多一个维度,在最后一个维度由输入指示的索引处值为 1,其他位置为 0。
示例
>>> F.one_hot(torch.arange(0, 5) % 3) tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]]) >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]) >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) tensor([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]])