torch.linalg.inv#
- torch.linalg.inv(A, *, out=None) Tensor#
计算方阵的逆(如果存在)。如果矩阵不可逆,则抛出 RuntimeError。
令 为 或 ,对于矩阵 , 其逆矩阵 (如果存在) 定义为
其中 是 n 维单位矩阵。
当且仅当 可逆时,逆矩阵才存在。在这种情况下,逆是唯一的。
支持浮点 (float)、双精度浮点 (double)、复数浮点 (cfloat) 和复数双精度浮点 (cdouble) 数据类型。还支持矩阵批处理,如果 `A` 是一个矩阵批处理,则输出具有相同的批处理维度。
注意
当输入在 CUDA 设备上时,此函数会同步该设备与 CPU。有关不进行同步的此函数版本,请参阅
torch.linalg.inv_ex()。注意
如果可能,请考虑使用
torch.linalg.solve()来将矩阵左乘逆矩阵,因为linalg.solve(A, B) == linalg.inv(A) @ B # When B is a matrix
如果可能,始终优先使用
solve(),因为它比显式计算逆矩阵更快、更数值稳定。- 参数:
A (Tensor) – 形状为 (*, n, n) 的张量,其中 * 是零个或多个批次维度,由可逆矩阵组成。
- 关键字参数:
out (Tensor, optional) – 输出张量。如果为 None 则忽略。默认为 None。
- 抛出:
RuntimeError – 如果矩阵
A或A的任何批次中的矩阵不可逆。
示例
>>> A = torch.randn(4, 4) >>> Ainv = torch.linalg.inv(A) >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(1.1921e-07) >>> A = torch.randn(2, 3, 4, 4) # Batch of matrices >>> Ainv = torch.linalg.inv(A) >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(1.9073e-06) >>> A = torch.randn(4, 4, dtype=torch.complex128) # Complex matrix >>> Ainv = torch.linalg.inv(A) >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(7.5107e-16, dtype=torch.float64)