广播语义#
创建日期: 2017年4月27日 | 最后更新日期: 2025年9月29日
许多 PyTorch 操作都支持 NumPy 的广播语义。有关详细信息,请参阅 https://numpy.com.cn/doc/stable/user/basics.broadcasting.html。
简而言之,如果一个 PyTorch 操作支持广播,那么它的 Tensor 参数可以自动扩展到具有相同的大小(而无需复制数据)。
通用语义#
如果满足以下规则,则两个张量是“可广播的”:
当从最后一个维度开始迭代维度大小时,维度大小必须相等,其中一个大小为 1,或者其中一个不存在。
例如
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because the 0-sized dimension of x
# does not match the 2-sized dimension of y.
# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
如果两个张量 x 和 y 是“可广播的”,则结果张量的大小计算如下:
如果
x和y的维度数量不相等,则在维度较少的张量的维度前加上 1,使其长度相等。然后,对于每个维度大小,结果维度大小是
x和y在该维度上的大小的最大值。
例如
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
原地语义#
一个复杂之处在于,原地操作不允许原地张量因广播而改变形状。
例如
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])
# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
向后兼容性#
PyTorch 的早期版本允许某些逐元素函数在具有不同形状的张量上执行,只要每个张量的元素数量相等。然后,逐元素操作将通过将每个张量视为一维来执行。PyTorch 现在支持广播,并且“一维”逐元素行为被认为已弃用,在张量不可广播但元素数量相同的情况下会生成 Python 警告。
请注意,引入广播可能会在两个张量形状不相同但可广播且元素数量相同的情况下导致向后不兼容的更改。例如:
>>> torch.add(torch.ones(4,1), torch.randn(4))
以前会生成大小为:torch.Size([4,1]) 的 Tensor,但现在生成大小为:torch.Size([4,4]) 的 Tensor。为了帮助识别代码中可能存在的由广播引入的向后不兼容的场景,您可以将 torch.utils.backcompat.broadcast_warning.enabled 设置为 True,这将在此类情况下生成 Python 警告。
例如
>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.