torch.einsum#
- torch.einsum(equation, *operands) Tensor [source]#
根据爱因斯坦求和约定指定的下标对输入
operands
的元素进行求和。Einsum 允许通过基于爱因斯坦求和约定的简写格式(由
equation
指定)来计算许多常见的 N 维线性代数数组运算。此格式的详细信息将在下面描述,但总体的思想是为输入operands
的每个维度标记一个下标,并定义哪些下标是输出的一部分。然后,通过对那些下标不包含在输出中的维度进行求和来计算输出。例如,矩阵乘法可以使用 einsum 计算为 torch.einsum(“ij,jk->ik”, A, B)。在此,j 是求和下标,i 和 k 是输出下标(有关为什么的更多详细信息,请参阅下面的部分)。Equation
The
equation
string specifies the subscripts (letters in [a-zA-Z]) for each dimension of the inputoperands
in the same order as the dimensions, separating subscripts for each operand by a comma (‘,’), e.g. ‘ij,jk’ specify subscripts for two 2D operands. The dimensions labeled with the same subscript must be broadcastable, that is, their size must either match or be 1. The exception is if a subscript is repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that appear exactly once in theequation
will be part of the output, sorted in increasing alphabetical order. The output is computed by multiplying the inputoperands
element-wise, with their dimensions aligned based on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.可以选择通过在方程末尾添加箭头(“->”)来显式定义输出下标,并在后面跟上输出的下标。例如,以下方程计算矩阵乘法的转置:“ij,jk->ki”。输出下标必须至少在某个输入操作数中出现一次,并且在输出中最多出现一次。
省略号(“…”)可以用来替代下标,以广播省略号覆盖的维度。每个输入操作数最多可以包含一个省略号,该省略号将覆盖未被下标覆盖的维度。例如,对于一个具有 5 个维度的输入操作数,方程 “ab…c” 中的省略号将覆盖第三和第四个维度。省略号不需要覆盖操作数(
operands
)之间相同数量的维度,但省略号的“形状”(由它们覆盖的维度的大小)必须能够一起广播。如果输出没有使用箭头(“->”)表示法显式定义,省略号将出现在输出的最前面(最左边的维度),然后是输入操作数中仅出现一次的下标标签。例如,以下方程实现了批量矩阵乘法 “…ij,…jk”。最后几点说明:方程可以在不同元素(下标、省略号、箭头和逗号)之间包含空格,但像 “…” 这样的写法是无效的。空字符串 “” 对于标量操作数是有效的。
注意
torch.einsum
处理省略号(“…”)的方式与 NumPy 不同,它允许省略号覆盖的维度被求和,也就是说,省略号不一定需要成为输出的一部分。注意
请安装 opt-einsum(https://optimized-einsum.readthedocs.io/en/stable/),以便使用更高效的 einsum。您可以像这样在安装 torch 时安装:pip install torch[opt-einsum],或者单独安装:pip install opt-einsum。
如果 opt-einsum 可用,此函数将通过我们的 opt_einsum 后端(
torch.backends.opt_einsum
)优化收缩顺序,从而自动加速计算和/或减少内存消耗(我知道 “_” 和 “-” 之间存在混淆)。当输入至少有三个时,就会发生这种优化,因为否则顺序无关紧要。请注意,找到“最佳”路径是一个 NP 难问题,因此,opt-einsum 依赖于不同的启发式方法来获得接近最优的结果。如果 opt-einsum 不可用,默认顺序是从左到右收缩。要绕过此默认行为,请添加以下内容以禁用 opt_einsum 并跳过路径计算:
torch.backends.opt_einsum.enabled = False
要指定 opt_einsum 计算收缩路径所使用的策略,请添加以下行:
torch.backends.opt_einsum.strategy = 'auto'
。默认策略是“auto”,我们也支持“greedy”和“optimal”。请注意,“optimal”策略的运行时复杂度是输入数量的阶乘!有关更多详细信息,请参阅 opt_einsum 文档(https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。注意
从 PyTorch 1.10 开始,
torch.einsum()
还支持子列表格式(请参见下面的示例)。在此格式中,每个操作数的下标由子列表指定,即范围在 [0, 52) 内的整数列表。这些子列表紧跟在它们的操作数之后,并且可以在输入末尾出现一个额外的子列表来指定输出的下标,例如:torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 对象可以包含在子列表中,以启用如上文“方程”部分所述的广播。示例
>>> # trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) >>> # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) >>> # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) >>> # batch matrix multiplication >>> As = torch.randn(3, 2, 5) >>> Bs = torch.randn(3, 5, 4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) >>> # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3, 5, 4) >>> l = torch.randn(2, 5) >>> r = torch.randn(2, 4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])