as_tensordict_module¶
- class tensordict.nn.as_tensordict_module(*, in_keys: Union[List[NestedKey], NestedKey], out_keys: Union[List[NestedKey], NestedKey])¶
一个将函数转换为 TensorDictModule 的装饰器。
- 参数:
in_keys (List[NestedKey] | NestedKey | None, 可选) – 结果 TensorDictModule 的输入键。
out_keys (List[NestedKey] | NestedKey | None, 可选) – 结果 TensorDictModule 的输出键。
- 返回:
一个可以应用于函数以将其转换为 TensorDictModule 的装饰器。
- 返回类型:
Callable
示例
>>> class MyClass: ... @as_tensordict_module(in_keys="c", out_keys="d") ... def my_method(self, c): ... return c + 1 >>> obj = MyClass() >>> result = obj.my_method(TensorDict(c=0)) >>> print(result["d"]) # prints: 1 >>> @as_tensordict_module(in_keys="c", out_keys="d") ... def my_function(c): ... return c + 1 >>> result = my_function(TensorDict(c=0)) >>> print(result["d"]) # prints: 1