注意
转到末尾 下载完整的示例代码。
使用 TensorDict 进行内存预分配¶
作者: Tom Begley
在本教程中,您将学习如何利用 TensorDict
中的内存预分配功能。
假设我们有一个返回 TensorDict
的函数
import torch
from tensordict.tensordict import TensorDict
def make_tensordict():
return TensorDict({"a": torch.rand(3), "b": torch.rand(3, 4)}, [3])
也许我们想多次调用此函数,并将结果填充到单个 TensorDict
中。
TensorDict(
fields={
a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([10, 3]),
device=None,
is_shared=False)
由于我们指定了 tensordict
的 batch_size
,在循环的第一次迭代中,我们使用其第一个维度大小为 N
,其余维度由 make_tensordict
的返回值决定的空张量填充 tensordict
。在上面的示例中,我们为键 "a"
预分配了一个大小为 torch.Size([10, 3])
的零数组,为键 "b"
预分配了一个大小为 torch.Size([10, 3, 4])
的零数组。后续循环迭代是就地写入的。因此,如果并非所有值都已填充,则它们将获得默认值零。
让我们通过逐步遍历上述循环来演示正在发生的事情。我们首先初始化一个空的 TensorDict
。
N = 10
tensordict = TensorDict({}, batch_size=[N, 3])
print(tensordict)
TensorDict(
fields={
},
batch_size=torch.Size([10, 3]),
device=None,
is_shared=False)
第一次迭代后,tensordict
已预填充了 "a"
和 "b"
的张量。这些张量包含零,除了我们为其分配了随机值的第一个行。
random_tensordict = make_tensordict()
tensordict[0] = random_tensordict
assert (tensordict[1:] == 0).all()
assert (tensordict[0] == random_tensordict).all()
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([10, 3]),
device=None,
is_shared=False)
后续迭代中,我们对预分配的张量进行就地更新。
脚本总运行时间: (0 分钟 0.003 秒)