torch.functional.lu#
- torch.functional.lu(*args, **kwargs)[source]#
计算矩阵或矩阵批次的 LU 分解
A。返回一个包含A的 LU 分解和透视(pivots)的元组。如果pivot设置为True,则进行部分主元法(partial pivoting)。警告
torch.lu()已弃用,推荐使用torch.linalg.lu_factor()和torch.linalg.lu_factor_ex()。torch.lu()将在 PyTorch 的未来版本中移除。LU, pivots, info = torch.lu(A, compute_pivots)应替换为LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)应替换为LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
注意
返回的批次中每个矩阵的置换矩阵(permutation matrix)表示为一个长度为
min(A.shape[-2], A.shape[-1])的 1-indexed 向量。pivots[i] == j表示在算法的第i步中,第i行与第j-1行进行了置换。带
pivot=False的 LU 分解不适用于 CPU,尝试这样做将引发错误。但是,带pivot=False的 LU 分解可用于 CUDA。如果
get_infos设置为True,此函数不会检查分解是否成功,因为分解的状态已包含在返回元组的第三个元素中。对于 CUDA 设备上大小小于或等于 32 的方形矩阵批次,由于 MAGMA 库中的一个 bug(请参见 magma issue 13),奇异矩阵的 LU 分解会被重复执行。
可以使用
torch.lu_unpack()推导出L、U和P。
警告
该函数的梯度仅在
A是满秩矩阵时才是有限的。这是因为 LU 分解仅在满秩矩阵上是可微的。此外,如果A接近于非满秩矩阵,则梯度在数值上是不稳定的,因为它依赖于 和 的计算。- 参数
A (Tensor) – 要分解的张量,大小为
pivot (bool, optional) – 是否要计算带部分主元法的 LU 分解,还是常规的 LU 分解。
pivot= False 在 CPU 上不受支持。默认为 True。get_infos (bool, optional) – 如果设置为
True,则返回一个 IntTensor。默认为Falseout (tuple, optional) – 可选的输出元组。如果
get_infos为True,则元组中的元素为 Tensor、IntTensor 和 IntTensor。如果get_infos为False,则元组中的元素为 Tensor 和 IntTensor。默认为None
- 返回
一个包含以下内容的张量元组:
factorization (Tensor): 分解结果,大小为
pivots (IntTensor): 透视(pivots)结果,大小为 。
pivots存储了所有的中间行交换。最终的置换perm可以通过对初始的 个元素的单位置换perm执行swap(perm[i], perm[pivots[i] - 1])来重构(对于i = 0, ..., pivots.size(-1) - 1),这本质上就是torch.lu_unpack()的作用。infos (IntTensor, optional): 如果
get_infos为True,则这是一个大小为 的张量,其中非零值表示矩阵或每个小批次的分解是否成功或失败。
- 返回类型
(Tensor, IntTensor, IntTensor (optional))
示例
>>> A = torch.randn(2, 3, 3) >>> A_LU, pivots = torch.lu(A) >>> A_LU tensor([[[ 1.3506, 2.5558, -0.0816], [ 0.1684, 1.1551, 0.1940], [ 0.1193, 0.6189, -0.5497]], [[ 0.4526, 1.2526, -0.3285], [-0.7988, 0.7175, -0.9701], [ 0.2634, -0.9255, -0.3459]]]) >>> pivots tensor([[ 3, 3, 3], [ 3, 3, 3]], dtype=torch.int32) >>> A_LU, pivots, info = torch.lu(A, get_infos=True) >>> if info.nonzero().size(0) == 0: ... print('LU factorization succeeded for all samples!') LU factorization succeeded for all samples!