tensordict.nn.set_skip_existing¶
- class tensordict.nn.set_skip_existing(mode: bool | None = True, in_key_attr='in_keys', out_key_attr='out_keys')¶
用于在 TensorDict 图中跳过现有节点的上下文管理器。
用作上下文管理器时,它会将 `skip_existing()` 的值设置为指定的 `mode`,让用户能够编写相应的代码来检查全局值并据此执行代码。
用作方法装饰器时,它会检查 tensordict 的输入键,如果 `skip_existing()` 调用返回 `True`,则当所有输出键都已存在时,将跳过该方法。此装饰器不适用于不遵循以下签名的函数:`def fun(self, tensordict, *args, **kwargs)`。
- 参数:
mode (bool, optional) – 如果为 `True`,则表示图中的现有条目不会被覆盖,除非它们是部分存在的。`skip_existing()` 将返回 `True`。如果为 `False`,则不会执行检查。如果为 `None`,则 `skip_existing()` 的值不会改变。这仅用于装饰方法,并允许它们的行为依赖于上下文管理器中的同一类(参见下面的示例)。默认为 `True`。
in_key_attr (str, optional) – 被装饰模块方法中的输入键列表属性的名称。默认为 `'in_keys'`。
out_key_attr (str, optional) – 被装饰模块方法中的输出键列表属性的名称。默认为 `'out_keys'`。
示例
>>> with set_skip_existing(): ... if skip_existing(): ... print("True") ... else: ... print("False") ... True >>> print("calling from outside:", skip_existing()) calling from outside: False
此类也可作为装饰器使用
示例
>>> from tensordict import TensorDict >>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase >>> class MyModule(TensorDictModuleBase): ... in_keys = [] ... out_keys = ["out"] ... @set_skip_existing() ... def forward(self, tensordict): ... print("hello") ... tensordict.set("out", torch.zeros(())) ... return tensordict >>> module = MyModule() >>> module(TensorDict({"out": torch.zeros(())}, [])) # does not print anything TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> module(TensorDict()) # prints hello hello TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
用 `mode` 设置为 `None` 来装饰一个方法,当你想让上下文管理器从外部负责跳过内容时非常有用。
示例
>>> from tensordict import TensorDict >>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase >>> class MyModule(TensorDictModuleBase): ... in_keys = [] ... out_keys = ["out"] ... @set_skip_existing(None) ... def forward(self, tensordict): ... print("hello") ... tensordict.set("out", torch.zeros(())) ... return tensordict >>> module = MyModule() >>> _ = module(TensorDict({"out": torch.zeros(())}, [])) # prints "hello" hello >>> with set_skip_existing(True): ... _ = module(TensorDict({"out": torch.zeros(())}, [])) # no print
注意
为了允许模块具有相同的输入和输出键而不至于错误地忽略子图,当输出键也是输入键时,`@set_skip_existing(True)` 将被禁用。
>>> class MyModule(TensorDictModuleBase): ... in_keys = ["out"] ... out_keys = ["out"] ... @set_skip_existing() ... def forward(self, tensordict): ... print("calling the method!") ... return tensordict ... >>> module = MyModule() >>> module(TensorDict({"out": torch.zeros(())}, [])) # does not print anything calling the method! TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)