torch.Tensor.masked_scatter_#
- Tensor.masked_scatter_(mask, source)#
将
source
中的元素复制到self
张量中mask
为 True 的位置。从source
的位置 0 开始,按顺序将source
中的元素逐个复制到self
中,每当mask
为 True 时进行一次复制。mask
的形状必须 可以广播 到底层张量的形状。source
的元素数量应至少等于mask
中 True 的数量。- 参数
mask (BoolTensor) – 布尔掩码
source (Tensor) – 要从中复制的张量
注意
mask
操作的是self
张量,而不是给定的source
张量。示例
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) >>> mask = torch.tensor( ... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], ... dtype=torch.bool, ... ) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]])