评价此页

torch.linalg.tensorinv#

torch.linalg.tensorinv(A, ind=2, *, out=None) Tensor#

计算 torch.tensordot() 的乘法逆。

如果 mA 的前 ind 维的乘积,n 是剩余维的乘积,那么该函数期望 mn 相等。如果相等,它会计算一个张量 X,使得 tensordot(A, X, ind) 是维度 m 上的单位矩阵。 X 的形状将与 A 相同,但前 ind 维会移到最后。

X.shape == A.shape[ind:] + A.shape[:ind]

支持浮点、双精度、复浮点和复双精度数据类型的输入。

注意

A 是一个 2 维张量且 ind= 1 时,此函数计算 A 的(乘法)逆(请参阅 torch.linalg.inv())。

注意

如果可能,请考虑使用 torch.linalg.tensorsolve() 来将张量左乘张量逆,因为

linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B)  # When B is a tensor with shape A.shape[:B.ndim]

始终优先使用 tensorsolve(),因为它比显式计算伪逆更快且数值更稳定。

另请参阅

torch.linalg.tensorsolve() 计算 torch.tensordot(tensorinv(A), B)

参数
  • A (Tensor) – 要求逆的张量。其形状必须满足 prod(A.shape[:ind]) == prod(A.shape[ind:])

  • ind (int) – 计算 torch.tensordot() 逆的索引。默认值:2

关键字参数

out (Tensor, optional) – 输出张量。如果为 None 则忽略。默认为 None

引发

RuntimeError – 如果重塑的 A 不可逆,或者前 ind 维的乘积不等于剩余维的乘积。

示例

>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3))
>>> Ainv = torch.linalg.tensorinv(A, ind=2)
>>> Ainv.shape
torch.Size([8, 3, 4, 6])
>>> B = torch.randn(4, 6)
>>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B))
True

>>> A = torch.randn(4, 4)
>>> Atensorinv = torch.linalg.tensorinv(A, ind=1)
>>> Ainv = torch.linalg.inv(A)
>>> torch.allclose(Atensorinv, Ainv)
True