
torch.repeat_interleave

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor

Repeat elements of a tensor.

Warning

This is different from torch.Tensor.repeat() but similar to numpy.repeat.
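The difference is easiest to see side by side. The following is a minimal sketch, assuming a small 1-D integer tensor chosen for illustration, that contrasts element-wise repetition with the tiling that torch.Tensor.repeat() performs.

```python
import torch

x = torch.tensor([1, 2, 3])

# repeat_interleave repeats each element in place, as numpy.repeat does.
interleaved = torch.repeat_interleave(x, 2)

# Tensor.repeat tiles the whole tensor, as numpy.tile does.
tiled = x.repeat(2)

print(interleaved)  # tensor([1, 1, 2, 2, 3, 3])
print(tiled)        # tensor([1, 2, 3, 1, 2, 3])
```

Both results have six elements, but only the ordering produced by repeat_interleave keeps copies of the same element adjacent.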

Parameters
  • input (Tensor) – the input tensor.

  • repeats (Tensor or int) – The number of repetitions for each element. repeats is broadcast to fit the shape of the given axis.

  • dim (int, optional) – The dimension along which to repeat values. By default, the input is flattened and a flat output tensor is returned.

Keyword Arguments

output_size (int, optional) – Total output size for the given axis (e.g. the sum of repeats). If given, it avoids the stream synchronization needed to calculate the output shape of the tensor.

Returns

Repeated tensor which has the same shape as input, except along the given axis.

Return type

Tensor

Example

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
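The shape rule for the returned tensor can be checked directly. The sketch below assumes a hypothetical 3×2 input and per-row counts picked for illustration; it shows that only the chosen dimension changes, that its new length equals the sum of the counts, and that passing that sum as output_size yields the same result.

```python
import torch

y = torch.tensor([[1, 2], [3, 4], [5, 6]])  # shape (3, 2), illustrative data
repeats = torch.tensor([2, 1, 3])           # one count per row along dim=0

out = torch.repeat_interleave(y, repeats, dim=0)

# Only dim 0 changes; its new length is the sum of the counts.
expected_rows = int(repeats.sum())
print(out.shape)  # torch.Size([6, 2])

# Supplying the known total as output_size gives an identical tensor.
out_sized = torch.repeat_interleave(y, repeats, dim=0, output_size=expected_rows)
print(torch.equal(out, out_sized))  # True

# Without dim, the input is flattened first, so the result is 1-D.
flat = torch.repeat_interleave(y, 2)
print(flat.shape)  # torch.Size([12])
```

Note that computing expected_rows from a tensor, as done here for clarity, itself reads the counts back; output_size is most useful when the total is already known from elsewhere.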

If repeats is tensor([n1, n2, n3, …]), then the output will be tensor([0, 0, …, 1, 1, …, 2, 2, …, …]) where 0 appears n1 times, 1 appears n2 times, 2 appears n3 times, etc.

torch.repeat_interleave(repeats, *) → Tensor

Repeats the index 0 repeats[0] times, the index 1 repeats[1] times, the index 2 repeats[2] times, etc.

Parameters

repeats (Tensor) – The number of repetitions for each element.

Returns

Repeated tensor of size sum(repeats).

Return type

Tensor

Example

>>> torch.repeat_interleave(torch.tensor([1, 2, 3]))
tensor([0, 1, 1, 2, 2, 2])
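The single-argument form can be read as repeating the indices 0, 1, 2, … by their counts. The sketch below illustrates that equivalence and one possible use, expanding per-group values to per-element values; the group_values tensor is hypothetical data introduced only for this example.

```python
import torch

repeats = torch.tensor([1, 2, 3])

# Single-argument form: index i appears repeats[i] times.
idx = torch.repeat_interleave(repeats)

# Equivalent two-argument form with explicit indices 0..len(repeats)-1.
same = torch.repeat_interleave(torch.arange(repeats.numel()), repeats)
print(torch.equal(idx, same))  # True

# Possible use: expand one value per group to one value per element.
group_values = torch.tensor([10.0, 20.0, 30.0])  # hypothetical data
expanded = group_values[idx]
print(expanded)  # tensor([10., 20., 20., 30., 30., 30.])
```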