torch.linalg.pinv#
- torch.linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) Tensor #
计算矩阵的伪逆(摩尔-彭罗斯逆)。
伪逆可以 代数上定义,但通过 SVD 来理解它在计算上更方便。
支持浮点 (float)、双精度浮点 (double)、复数浮点 (cfloat) 和复数双精度浮点 (cdouble) 数据类型。还支持矩阵批处理,如果 `A` 是一个矩阵批处理,则输出具有相同的批处理维度。
如果
hermitian
= True,则假设A
是厄米特(复数)或对称(实数)的,但这不会在内部进行检查。相反,在计算中仅使用矩阵的下三角部分。奇异值(或特征值的范数,当
hermitian
= True 时)低于 的阈值将被视为零并在计算中被忽略,其中 是最大的奇异值(或特征值)。如果未指定
rtol
,并且A
是维度为 (m, n) 的矩阵,则相对容差设置为 ,其中 是A
的 dtype 的 epsilon 值(参见finfo
)。如果未指定rtol
且atol
被指定为大于零,则rtol
将被设置为零。如果
atol
或rtol
是一个torch.Tensor
,则其形状必须可广播到A
的奇异值形状,这些奇异值由torch.linalg.svd()
返回。注意
如果
hermitian
= False,此函数使用torch.linalg.svd()
;如果hermitian
= True,则使用torch.linalg.eigh()
。对于 CUDA 输入,此函数会将该设备与 CPU 同步。注意
如果可能,请考虑使用
torch.linalg.lstsq()
将伪逆乘以左侧,因为torch.linalg.lstsq(A, B).solution == A.pinv() @ B
如果可能,始终优先使用
lstsq()
,因为它比显式计算伪逆更快且数值更稳定。注意
此函数有一个与 NumPy 兼容的版本 linalg.pinv(A, rcond, hermitian=False)。但是,使用位置参数
rcond
已弃用,建议使用rtol
。警告
此函数内部使用
torch.linalg.svd()
(或在hermitian
= True 时使用torch.linalg.eigh()
),因此其导数与这些函数导数存在相同的问题。有关更多详细信息,请参阅torch.linalg.svd()
和torch.linalg.eigh()
中的警告。- 参数
- 关键字参数
示例
>>> A = torch.randn(3, 5) >>> A tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) >>> torch.linalg.pinv(A) tensor([[ 0.0600, -0.1933, -0.2090], [-0.0903, -0.0817, -0.4752], [-0.7124, -0.1631, -0.2272], [ 0.1356, 0.3933, -0.5023], [-0.0308, -0.1725, -0.5216]]) >>> A = torch.randn(2, 6, 3) >>> Apinv = torch.linalg.pinv(A) >>> torch.dist(Apinv @ A, torch.eye(3)) tensor(8.5633e-07) >>> A = torch.randn(3, 3, dtype=torch.complex64) >>> A = A + A.T.conj() # creates a Hermitian matrix >>> Apinv = torch.linalg.pinv(A, hermitian=True) >>> torch.dist(Apinv @ A, torch.eye(3)) tensor(1.0830e-06)