• 文档 >
  • 使用 TensorDict 预分配内存
快捷方式

使用 TensorDict 进行内存预分配

作者: Tom Begley

在本教程中,您将学习如何利用 TensorDict 中的内存预分配功能。

假设我们有一个函数,它返回一个 TensorDict

import torch
from tensordict.tensordict import TensorDict


def make_tensordict():
    return TensorDict({"a": torch.rand(3), "b": torch.rand(3, 4)}, [3])

也许我们想多次调用此函数,并将结果填充到一个 TensorDict 中。

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])

for i in range(N):
    tensordict[i] = make_tensordict()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

由于我们已经指定了 tensordictbatch_size,在循环的第一次迭代中,我们用空张量填充 tensordict,其第一个维度的大小为 N,其余维度由 make_tensordict 的返回值确定。在上面的示例中,我们为键 "a" 预分配了一个大小为 torch.Size([10, 3]) 的零数组,并为键 "b" 预分配了一个大小为 torch.Size([10, 3, 4]) 的数组。后续的循环迭代是就地写入的。因此,如果并非所有值都已填充,它们将获得默认值零。

让我们通过逐步分析上述循环来演示正在发生的情况。我们首先初始化一个空的 TensorDict

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])
print(tensordict)
TensorDict(
    fields={
    },
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

第一次迭代后,tensordict 已经为 "a""b" 预先填充了张量。这些张量包含零,除了我们为其分配了随机值的第一行。

random_tensordict = make_tensordict()
tensordict[0] = random_tensordict

assert (tensordict[1:] == 0).all()
assert (tensordict[0] == random_tensordict).all()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

在后续的迭代中,我们对预分配的张量进行就地更新。

a = tensordict["a"]
random_tensordict = make_tensordict()
tensordict[1] = random_tensordict

# the same tensor is stored under "a", but the values have been updated
assert tensordict["a"] is a
assert (tensordict[:2] != 0).all()

脚本总运行时间: (0 分 0.003 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源