Note
Go to the end to download the full example code.
Manipulating the keys of a TensorDict¶
Author: Tom Begley
In this tutorial you will learn how to work with and manipulate the keys of a
TensorDict, including getting and setting keys, iterating over keys, manipulating nested values, and flattening keys.
Setting and getting keys¶
We can set and retrieve keys using the same syntax as for a Python
dict.
import torch
from tensordict.tensordict import TensorDict
tensordict = TensorDict()
# set a key
a = torch.rand(10)
tensordict["a"] = a
# retrieve the value stored under "a"
assert tensordict["a"] is a
Note
Unlike a Python dict,
all keys in a TensorDict
must be strings. However, as we will see, it is also possible to manipulate nested values with tuples of strings.
We can also achieve the same thing with the .get()
and .set()
methods.
tensordict = TensorDict()
# set a key
a = torch.rand(10)
tensordict.set("a", a)
# retrieve the value stored under "a"
assert tensordict.get("a") is a
As with dict,
we can provide a default value to get
that should be returned in case the requested key is not found.
Likewise, as with dict,
we can use TensorDict.setdefault()
to get the value of a particular key, returning a default value and setting that value in the TensorDict
if the key is not found.
Deleting keys is also done the same way as with a Python dict:
using the del
statement with the chosen key. Equivalently, we can use the TensorDict.del_
method.
tensordict["banana"] = torch.rand(10)
del tensordict["banana"]
Furthermore, when setting keys with .set()
we can use the keyword argument inplace=True
to make an in-place update, or equivalently use the .set_()
method.
tensordict.set("a", torch.zeros(10), inplace=True)
# all the entries of the "a" tensor are now zero
assert (tensordict.get("a") == 0).all()
# but it's still the same tensor as before
assert tensordict.get("a") is a
# we can achieve the same with set_
tensordict.set_("a", torch.ones(10))
assert (tensordict.get("a") == 1).all()
assert tensordict.get("a") is a
Renaming keys¶
To rename a key, simply use the TensorDict.rename_key_
method. The value stored under the original key remains in the TensorDict,
but the key is changed to the one specified.
tensordict.rename_key_("a", "b")
assert tensordict.get("b") is a
print(tensordict)
TensorDict(
fields={
b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
Updating multiple values¶
A TensorDict
can be updated from another TensorDict
or from a dict
with the TensorDict.update
method. Keys that already exist are overwritten, and keys that do not yet exist are created.
tensordict = TensorDict({"a": torch.rand(10), "b": torch.rand(10)}, [10])
tensordict.update(TensorDict({"a": torch.zeros(10), "c": torch.zeros(10)}, [10]))
assert (tensordict["a"] == 0).all()
assert (tensordict["b"] != 0).all()
assert (tensordict["c"] == 0).all()
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
c: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([10]),
device=None,
is_shared=False)
Nested values¶
The values of a TensorDict
can themselves be TensorDicts.
We can add nested values during instantiation, either by adding TensorDicts
directly, or by using nested dicts.
# creating nested values with a nested dict
nested_tensordict = TensorDict(
    {"a": torch.rand(2, 3), "double_nested": {"a": torch.rand(2, 3)}}, [2, 3]
)
# creating nested values with a TensorDict
tensordict = TensorDict({"a": torch.rand(2), "nested": nested_tensordict}, [2])
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
To access these nested values, we can index with tuples of strings. For example:
double_nested_a = tensordict["nested", "double_nested", "a"]
nested_a = tensordict.get(("nested", "a"))
Similarly, we can set nested values using tuples of strings:
tensordict["nested", "double_nested", "b"] = torch.rand(2, 3)
tensordict.set(("nested", "b"), torch.rand(2, 3))
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
Iterating over a TensorDict's contents¶
We can iterate over the keys of a TensorDict
using the .keys()
method.
for key in tensordict.keys():
    print(key)
a
nested
By default, this iterates only over the top-level keys of the TensorDict.
It is however possible to iterate recursively over all of the keys of the TensorDict
with the keyword argument include_nested=True.
This will recurse into any nested TensorDicts, returning nested keys as tuples of strings.
for key in tensordict.keys(include_nested=True):
    print(key)
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'double_nested')
('nested', 'b')
nested
If you would like to iterate only over keys corresponding to Tensor
values, you can additionally specify leaves_only=True.
for key in tensordict.keys(include_nested=True, leaves_only=True):
    print(key)
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'b')
Much like a dict,
there are also .values
and .items
methods, which accept the same keyword arguments.
for key, value in tensordict.items(include_nested=True):
    if isinstance(value, TensorDict):
        print(f"{key} is a TensorDict")
    else:
        print(f"{key} is a Tensor")
a is a Tensor
nested is a TensorDict
('nested', 'a') is a Tensor
('nested', 'double_nested') is a TensorDict
('nested', 'double_nested', 'a') is a Tensor
('nested', 'double_nested', 'b') is a Tensor
('nested', 'b') is a Tensor
Checking for existence of a key¶
To check if a key exists in a TensorDict,
use the in
operator in combination with .keys().
Note
key in tensordict.keys()
performs an efficient dict
key lookup (recursively at each level in the nested case), so performance is not negatively impacted when there is a large number of keys in the TensorDict.
assert "a" in tensordict.keys()
# to check for nested keys, set include_nested=True
assert ("nested", "a") in tensordict.keys(include_nested=True)
assert ("nested", "banana") not in tensordict.keys(include_nested=True)
Flattening and unflattening nested keys¶
We can flatten a TensorDict
with nested values using the .flatten_keys()
method.
print(tensordict, end="\n\n")
print(tensordict.flatten_keys(separator="."))
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
Given a TensorDict
that has been flattened, it is possible to unflatten it again with the .unflatten_keys()
method.
flattened_tensordict = tensordict.flatten_keys(separator=".")
print(flattened_tensordict, end="\n\n")
print(flattened_tensordict.unflatten_keys(separator="."))
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
This can be particularly useful when manipulating the parameters of a torch.nn.Module,
since we can end up with a TensorDict
whose structure mimics the module structure.
import torch.nn as nn
module = nn.Sequential(
    nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 10)),
    nn.Linear(10, 1),
)
params = TensorDict(dict(module.named_parameters()), []).unflatten_keys()
print(params)
TensorDict(
fields={
0: TensorDict(
fields={
0: TensorDict(
fields={
bias: Parameter(shape=torch.Size([50]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([50, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
1: TensorDict(
fields={
bias: Parameter(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([10, 50]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
1: TensorDict(
fields={
bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
Selecting and excluding keys¶
We can obtain a new TensorDict
containing a subset of the keys with TensorDict.select,
which returns a new TensorDict
with only the specified keys, or with TensorDict.exclude, which returns a new TensorDict
with the specified keys omitted.
print("Select:")
print(tensordict.select("a", ("nested", "a")), end="\n\n")
print("Exclude:")
print(tensordict.exclude(("nested", "b"), ("nested", "double_nested")))
Select:
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
Exclude:
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
Total running time of the script: (0 minutes 0.009 seconds)