torch.triangular_solve#
- torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)#
求解具有方阵(上三角或下三角可逆矩阵) 和多个右侧项 的方程组。
用符号表示,它求解 ,并假设 是方上三角(如果
upper= False 则为方下三角)并且对角线上没有零。torch.triangular_solve(b, A) 可以接受 2D 输入 b, A 或批量 2D 矩阵的输入。如果输入是批量的,则返回批量的输出 X
如果
A的对角线包含零或接近零的元素,并且unitriangular= False(默认值),或者如果输入矩阵条件较差,结果可能包含 NaN。支持 float, double, cfloat 和 cdouble 数据类型的输入。
警告
torch.triangular_solve()已弃用,推荐使用torch.linalg.solve_triangular(),并且将在未来的 PyTorch 版本中移除。torch.linalg.solve_triangular()的参数顺序已颠倒,并且不返回输入之一的副本。X = torch.triangular_solve(B, A).solution应替换为X = torch.linalg.solve_triangular(A, B)
- 参数
b (Tensor) – 多个右侧项,大小为 ,其中 是零个或多个批次维度
A (Tensor) – 输入的三角系数矩阵,大小为 ,其中 是零个或多个批次维度
upper (bool, optional) – 是上三角还是下三角。默认为
True。transpose (bool, optional) – 求解 op(A)X = b,其中当此标志为
True时 op(A) = A^T,当此标志为False时 op(A) = A。默认为False。unitriangular (bool, optional) – 是否为单位三角矩阵。如果为 True,则假定 的对角线元素为 1 并且不从 中引用。默认为
False。
- 关键字参数
out ((Tensor, Tensor), optional) – 用于写入输出的两个张量的元组。如果为 None 则忽略。默认为 None。
- 返回
一个命名元组 (solution, cloned_coefficient),其中 cloned_coefficient 是 的克隆,而 solution 是方程 (或根据关键字参数的方程变体)的解 。
示例
>>> A = torch.randn(2, 2).triu() >>> A tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) >>> b = torch.randn(2, 3) >>> b tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) >>> torch.triangular_solve(b, A) torch.return_types.triangular_solve( solution=tensor([[ 1.7841, 2.9046, -2.5405], [ 1.9320, 0.9270, -1.2826]]), cloned_coefficient=tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]))