评价此页

torch.trace#

torch.trace(input) Tensor#

返回输入二维矩阵对角线元素之和。

示例

>>> x = torch.arange(1., 10.).view(3, 3)
>>> x
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.]])
>>> torch.trace(x)
tensor(15.)