评价此页

torch.Tensor.masked_scatter_#

Tensor.masked_scatter_(mask, source)#

source 中的元素复制到 self 张量中 mask 为 True 的位置。从 source 的位置 0 开始,按顺序将 source 中的元素逐个复制到 self 中,每当 mask 为 True 时进行一次复制。 mask 的形状必须 可以广播 到底层张量的形状。 source 的元素数量应至少等于 mask 中 True 的数量。

参数
  • mask (BoolTensor) – 布尔掩码

  • source (Tensor) – 要从中复制的张量

注意

mask 操作的是 self 张量,而不是给定的 source 张量。

示例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor(
...     [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
...     dtype=torch.bool,
... )
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])