评价此页

torch.matmul#

torch.matmul(input, other, *, out=None) Tensor#

两个张量的矩阵乘积。

行为取决于张量的维度,如下所示:

  • 如果两个张量都是一维的,则返回点积(标量)。

  • 如果两个参数都是二维的,则返回矩阵-矩阵乘积。

  • 如果第一个参数是一维的,第二个参数是二维的,则在进行矩阵乘法时,其维度前面会加上一个 1。矩阵乘法完成后,会移除添加的维度。

  • 如果第一个参数是二维的,第二个参数是一维的,则返回矩阵-向量乘积。

  • 如果两个参数至少为一维,并且至少有一个参数是 N 维(N > 2),则返回批处理矩阵乘法。如果第一个参数是一维的,则在进行批处理矩阵乘法时,其维度前面会加上一个 1,并在之后移除。如果第二个参数是一维的,则在进行批处理矩阵乘法时,其维度后面会加上一个 1,并在之后移除。

    每个参数的前 N-2 个维度(批处理维度)将进行广播(因此必须是可广播的)。最后的 2 个维度(矩阵维度)按照矩阵-矩阵乘积的方式处理。

    例如,如果 input 是一个 (j×1×n×m)(j \times 1 \times n \times m) 张量,而 other 是一个 (k×m×p)(k \times m \times p) 张量,则批处理维度为 (j×1)(j \times 1)(k)(k),矩阵维度为 (n×m)(n \times m)(m×p)(m \times p)out 将是一个 (j×k×n×p)(j \times k \times n \times p) 张量。

此操作支持具有稀疏布局的参数。特别是,矩阵-矩阵(两个参数都是二维的)支持稀疏参数,其限制与torch.mm()相同。

警告

稀疏支持是测试版功能,某些布局/数据类型/设备组合可能不支持,或可能不支持自动求导。如果您发现缺少功能,请提交功能请求。

此操作符支持TensorFloat32

在某些 ROCm 设备上,当使用 float16 输入时,此模块将对反向传播使用不同精度

注意

此函数的 out 参数的一维点积版本不支持。

参数
  • input (Tensor) – 要相乘的第一个张量

  • other (Tensor) – 要相乘的第二个张量

关键字参数

out (Tensor, optional) – 输出张量。

示例

>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])