评价此页

Tensor 视图#

创建日期: 2020年2月28日 | 最后更新日期: 2025年2月26日

PyTorch 允许一个 tensor 成为现有 tensor 的一个 视图 (View)。视图 tensor 与其基 tensor 共享相同的底层数据。支持 视图 (View) 避免了显式的数据复制,从而使我们能够进行快速且内存高效的重塑、切片和逐元素操作。

例如,要获取现有 tensor t 的视图,可以调用 t.view(...)

>>> t = torch.rand(4, 4)
>>> b = t.view(2, 8)
>>> t.storage().data_ptr() == b.storage().data_ptr()  # `t` and `b` share the same underlying data.
True
# Modifying view tensor changes base tensor as well.
>>> b[0][0] = 3.14
>>> t[0][0]
tensor(3.14)

由于视图与其基 tensor 共享底层数据,因此如果您编辑视图中的数据,它也会反映在基 tensor 中。

通常,PyTorch 操作会返回一个新 tensor 作为输出,例如 add()。但在视图操作的情况下,输出是输入 tensor 的视图,以避免不必要的数据复制。创建视图时不会发生数据移动,视图 tensor 只是改变了解释相同数据的方式。获取连续 tensor 的视图可能会产生一个非连续的 tensor。用户应额外注意,连续性可能对性能有隐式影响。 transpose() 是一个常见的例子。

>>> base = torch.tensor([[0, 1],[2, 3]])
>>> base.is_contiguous()
True
>>> t = base.transpose(0, 1)  # `t` is a view of `base`. No data movement happened here.
# View tensors might be non-contiguous.
>>> t.is_contiguous()
False
# To get a contiguous tensor, call `.contiguous()` to enforce
# copying data when `t` is not contiguous.
>>> c = t.contiguous()

供参考,这是 PyTorch 中视图操作的完整列表

注意

通过索引访问 tensor 的内容时,PyTorch 遵循 Numpy 的行为,即基本索引返回视图,而高级索引返回副本。通过基本或高级索引进行的赋值是就地进行的。有关更多示例,请参见 Numpy 索引文档

还需要提及一些具有特殊行为的操作

  • reshape()reshape_as()flatten() 可以返回视图或新 tensor,用户代码不应依赖于它是视图还是新 tensor。

  • contiguous() 在输入 tensor 已经连续时返回**自身**,否则通过复制数据返回一个新的连续 tensor。

有关 PyTorch 内部实现的更详细的演练,请参阅 ezyang 关于 PyTorch Internals 的博客文章