评价此页

Tensor.masked_scatter_() 将源中的元素复制到掩码为 True 的 self 张量中的位置。源中的元素按顺序一个接一个地复制到掩码为 True 的每个出现的 self 中。掩码的形状必须可以与基础张量的形状进行广播。源应包含的元素数量至少等于掩码中 True 的数量。

Tensor.masked_scatter_(mask, source)

将源中的元素复制到掩码为 True 的 self 张量中的位置。源中的元素按顺序一个接一个地复制到掩码为 True 的每个出现的 self 中。掩码的形状必须可以与基础张量的形状进行广播。源应包含的元素数量至少等于掩码中 True 的数量。

参数
  • mask (BoolTensor) – 布尔掩码

  • source (Tensor) – 要从中复制的张量

注意

掩码作用于 self 张量,而不是给定的源张量。

示例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor(
...     [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
...     dtype=torch.bool,
... )
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])