torch.Tensor.scatter_add_#
- Tensor.scatter_add_(dim, index, src) Tensor #
类似于
scatter_()
,将src
张量中的所有值加到index
张量指定的self
的索引处。对于src
中的每个值,它被加到self
的一个索引中,该索引通过其在src
中的索引(对于dimension != dim
)和index
中的相应值(对于dimension = dim
)来指定。对于一个 3-D 张量,
self
的更新方式为self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
self
,index
和src
必须具有相同的维度数。此外,对于所有维度d
,要求index.size(d) <= src.size(d)
,并且对于所有维度d != dim
,要求index.size(d) <= self.size(d)
。请注意,index
和src
不会广播。当index
为空时,我们始终返回原始张量,而不进行进一步的错误检查。注意
当在 CUDA 设备上使用张量时,此操作可能行为不确定。有关更多信息,请参阅 随机性。
注意
反向传播仅对
src.shape == index.shape
进行了实现。- 参数
示例
>>> src = torch.ones((2, 5)) >>> index = torch.tensor([[0, 1, 2, 0, 0]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) tensor([[1., 0., 0., 1., 1.], [0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]) >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) tensor([[2., 0., 0., 1., 1.], [0., 2., 0., 0., 0.], [0., 0., 2., 1., 1.]])