torch.linalg.multi_dot#
- torch.linalg.multi_dot(tensors, *, out=None)#
高效地通过重新排序乘法来计算两个或多个矩阵的乘积,以执行最少的算术运算。
支持 float, double, cfloat 和 cdouble 数据类型的输入。此函数不支持批处理输入。
在
tensors
中的每个张量都必须是 2D 的,除了第一个和最后一个张量,它们可以是 1D 的。如果第一个张量是形状为 (n,) 的 1D 向量,则将其视为形状为 (1, n) 的行向量;类似地,如果最后一个张量是形状为 (n,) 的 1D 向量,则将其视为形状为 (n, 1) 的列向量。如果第一个和最后一个张量是矩阵,则输出将是矩阵。但是,如果其中一个是 1D 向量,则输出将是 1D 向量。
与 numpy.linalg.multi_dot 的区别
与 numpy.linalg.multi_dot 不同,第一个和最后一个张量必须是 1D 或 2D,而 NumPy 允许它们是 nD。
警告
此函数不执行广播。
注意
此函数通过在计算最佳矩阵乘法顺序后链式调用
torch.mm()
来实现。注意
形状为 (a, b) 和 (b, c) 的两个矩阵相乘的成本为 a * b * c。给定形状分别为 (10, 100)、(100, 5) 和 (5, 50) 的矩阵 A、B、C,我们可以计算不同乘法顺序的成本如下:
在这种情况下,先计算 A 和 B 的乘积,然后再乘以 C 的速度是后者的 10 倍。
- 参数
tensors (Sequence[Tensor]) – 要相乘的两个或多个张量。第一个和最后一个张量可以是 1D 或 2D。所有其他张量都必须是 2D。
- 关键字参数
out (Tensor, optional) – 输出张量。如果为 None 则忽略。默认为 None。
示例
>>> from torch.linalg import multi_dot >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) tensor(8) >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) tensor([8]) >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) tensor([[8]]) >>> A = torch.arange(2 * 3).view(2, 3) >>> B = torch.arange(3 * 2).view(3, 2) >>> C = torch.arange(2 * 2).view(2, 2) >>> multi_dot((A, B, C)) tensor([[ 26, 49], [ 80, 148]])