TensorDictModule¶
- class tensordict.nn.TensorDictModule(*args, **kwargs)¶
TensorDictModule 是一个 Python 包装器,用于读取和写入 TensorDict 的
nn.Module
。- 参数:
module (Callable[[Any], Any]) – 一个可调用对象,通常是
torch.nn.Module
,用于将输入映射到输出参数空间。其 forward 方法可以返回单个张量、张量元组甚至字典。在后一种情况下,TensorDictModule
的输出键将用于填充输出 tensordict(即out_keys
中的键应存在于module
forward 方法返回的字典中)。in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – 从输入 tensordict 读取并传递给模块的键。如果它包含多个元素,则值将按 in_keys 可迭代对象给定的顺序传递。如果
in_keys
是一个字典,其键必须对应于 tensordict 中要读取的键,其值必须与函数签名中的关键字参数名称匹配。如果 out_to_in_map 为True
,则映射会反转,以便键对应于函数签名中的关键字参数。out_keys (iterable of str) – 写入输入 tensordict 的键。out_keys 的长度必须与嵌入式模块返回的张量数量匹配。使用“_”作为键可避免将张量写入输出。
- 关键字参数:
out_to_in_map (bool, optional) – 如果为
True
(默认值),则 in_keys 的读取方式就好像键是forward()
方法的参数键,值是输入TensorDict
中的键。如果为False
,则键被视为输入键,值被视为方法的参数键。inplace (bool or string, optional) –
如果为
True
(默认值),则模块的输出将写入提供给forward()
方法的 tensordict。如果为False
,则会创建一个具有空批次大小且没有设备的新TensorDict
。如果为"empty"
,则将使用empty()
来创建输出 tensordict。注意
如果
inplace=False
并且传递给模块的 tensordict 是TensorDictBase
的子类而不是TensorDict
,则输出仍将是TensorDict
实例。其批次大小将为空,并且没有设备。设置为"empty"
以获得相同的TensorDictBase
子类型、相同的批次大小和设备。使用运行时tensordict_out
(见下文)可以更精细地控制输出。注意
如果
inplace=False
并且tensordict_out
被传递给forward()
方法,则tensordict_out
将优先。这是获取 tensordict_out 的方式,当传递给模块的 tensordict 是TensorDictBase
的子类而不是TensorDict
时,输出仍将是TensorDict
实例。method (str, optional) – 要在模块中调用的方法(如果有)。默认为 __call__。
method_kwargs (Dict[str, Any], optional) – 要传递给模块正在调用的方法的其他关键字参数。
strict (bool, optional) – 如果为
True
,则如果输入 tensordict 中缺少任何输入,模块将引发异常。否则,将使用 None 值作为占位符。默认为False
。get_kwargs (dict[str, Any], optional) – 要传递给
get()
方法的其他关键字参数。这在处理不规则张量时尤其有用(参见get()
)。默认为{}
。
嵌入神经网络到 TensorDictModule 中只需要指定输入和输出键。TensorDictModule 支持函数式和常规的
nn.Module
对象。在函数式情况下,必须指定 'params'(以及 'buffers')关键字参数。示例
>>> from tensordict import TensorDict >>> # one can wrap regular nn.Module >>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"]) >>> input = torch.ones(2, 3, 128) >>> tgt = torch.zeros(2, 3, 128) >>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3]) >>> data = module(data) >>> print(data) TensorDict( fields={ input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False), out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False), tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2, 3]), device=None, is_shared=False)
我们也可以直接传递张量。
示例
>>> out = module(input, tgt) >>> assert out.shape == input.shape >>> # we can also wrap regular functions >>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")]) >>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[])) TensorDict( fields={ input: TensorDict( fields={ x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), output: TensorDict( fields={ x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
我们可以使用 TensorDictModule 来填充 tensordict。
示例
>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"]) >>> print(module(TensorDict({}, batch_size=[]))) TensorDict( fields={ x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
另一个特性是传递字典作为输入键,以控制将值分派到特定的关键字参数。
示例
>>> module = TensorDictModule(lambda x, *, y: x+y, ... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False ... ) >>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, [])) >>> td['z'] tensor(3.)
如果 out_to_in_map 设置为
True
,则 in_keys 映射将被反转。这样,就可以将同一个输入键用于不同的关键字参数。示例
>>> module = TensorDictModule(lambda x, *, y, z: x+y+z, ... in_keys={'x': '1', 'y': '2', z: '2'}, out_keys=['t'], out_to_in_map=True ... ) >>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, [])) >>> td['t'] tensor(5.)
我们可以指定模块内要调用的方法。与使用 lambda 函数或类似方式包装模块方法相比,这有一个优点,即模块的属性(params、buffers、submodules)将暴露出来。
示例
>>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> from torch import nn >>> import torch >>> >>> class MyNet(nn.Module): ... def my_func(self, tensor: torch.Tensor, *, an_integer: int): ... return tensor + an_integer ... >>> s = Seq( ... { ... "a": lambda td: td+1, ... "b": lambda td: td * 2, ... "c": Mod(MyNet(), in_keys=["a"], out_keys=["b"], method="my_func", method_kwargs={"an_integer": 2}), ... } ... ) >>> td = s(TensorDict(a=0)) >>> print(td) >>> >>> assert td["b"] == 4
对 tensordict 模块进行函数式调用很容易。
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) >>> module = torch.nn.GRUCell(4, 8) >>> td_module = TensorDictModule( ... module=module, in_keys=["input", "hidden"], out_keys=["output"] ... ) >>> params = TensorDict.from_module(td_module) >>> # functional API >>> with params.to_module(td_module): ... td_functional = td_module(td.clone()) >>> print(td_functional) TensorDict( fields={ hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- 在有状态情况下。
>>> module = torch.nn.GRUCell(4, 8) >>> td_module = TensorDictModule( ... module=module, in_keys=["input", "hidden"], out_keys=["output"] ... ) >>> td_stateful = td_module(td.clone()) >>> print(td_stateful) TensorDict( fields={ hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) TensorDictBase ¶
当 tensordict 参数未设置时,kwargs 用于创建 TensorDict 的实例。