WrapModule¶
- class tensordict.nn.WrapModule(*args, **kwargs)¶
一个包装任何处理 TensorDict 实例的可调用对象的包装器。
当构建
TensorDictSequential
堆栈以及转换需要整个 TensorDict 实例可见时,此包装器非常有用。- 参数:
func (Callable[[TensorDictBase], TensorDictBase]) – 一个可调用函数,它接收一个 TensorDictBase 实例并返回一个转换后的 TensorDictBase 实例。
- 关键字参数:
inplace (bool, optional) – 如果为
True
,则输入 TensorDict 将被就地修改。否则,将返回一个新的 TensorDict(如果函数不就地修改并返回它)。默认为False
。in_keys (list of NestedKey, optional) – 如果提供,则指示模块读取哪些条目。这不会被检查,仅用于向
TensorDictSequential
提供有关被包装模块输入键的信息。默认为 []。out_keys (list of NestedKey, optional) – 如果提供,则指示模块写入哪些条目。这不会被检查,仅用于向
TensorDictSequential
提供有关被包装模块输出键的信息。默认为 []。
示例
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule >>> seq = Seq( ... Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]), ... WrapModule(lambda td: td.reshape(-1)), ... ) >>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4]) >>> td = Seq(td) >>> assert td.shape == (12,) >>> assert (td["y"] == 2).all() >>> assert td["y"].shape == (12, 5)
- forward(data: TensorDictBase) TensorDictBase ¶
定义每次调用时执行的计算。
所有子类都应重写此方法。
注意
虽然前向传递的实现需要在此函数中定义,但之后应调用
Module
实例而不是此函数,因为前者负责运行已注册的钩子,而后者会静默地忽略它们。