评价此页

torch.frombuffer#

torch.frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) Tensor#

从实现 Python buffer 协议的对象创建一个一维 Tensor

跳过 buffer 中的前 offset 字节,并将剩余的原始字节解释为指定 dtype 类型、包含 count 个元素的⼀维张量。

注意,以下条件之一必须满足:

1. count 为正且非零,并且 buffer 中的总字节数大于 offset 加上 count 乘以 dtype 的大小(以字节为单位)。

2. count 为负数,并且 buffer 的长度(字节数)减去 offsetdtype 的大小(以字节为单位)的倍数。

返回的 tensor 和 buffer 共享相同的内存。对 tensor 的修改将反映在 buffer 中,反之亦然。返回的 tensor 不可调整大小。

注意

此函数会增加拥有共享内存的对象的引用计数。因此,此类内存不会在返回的 tensor 作用域结束前被释放。

警告

当传递一个实现了 buffer 协议但其数据不在 CPU 上的对象时,此函数的行为未定义。这样做很可能会导致段错误。

警告

此函数不会尝试推断 dtype(因此,它不是可选的)。传递与源不同的 dtype 可能会导致意外行为。

参数

buffer (object) – 一个公开 buffer 接口的 Python 对象。

关键字参数
  • dtype (torch.dtype) – 返回 tensor 的期望数据类型。

  • count (int, optional) – 要读取的期望元素数量。如果为负数,则读取所有元素(直到 buffer 末尾)。默认为 -1。

  • offset (int, optional) – 要在 buffer 开头跳过的字节数。默认为 0。

  • requires_grad (bool, optional) – 如果 autograd 应记录在返回的张量上的操作。默认值:False

示例

>>> import array
>>> a = array.array('i', [1, 2, 3])
>>> t = torch.frombuffer(a, dtype=torch.int32)
>>> t
tensor([ 1,  2,  3])
>>> t[0] = -1
>>> a
array([-1,  2,  3])

>>> # Interprets the signed char bytes as 32-bit integers.
>>> # Each 4 signed char elements will be interpreted as
>>> # 1 signed 32-bit integer.
>>> import array
>>> a = array.array('b', [-1, 0, 0, 0])
>>> torch.frombuffer(a, dtype=torch.int32)
tensor([255], dtype=torch.int32)