torch.compiler.skip_guard_on_all_nn_modules_unsafe#
- torch.compiler.skip_guard_on_all_nn_modules_unsafe(guard_entries)[源代码]#
一个通用的函数,用于跳过对所有 nn 模块的 guard(保护),包括用户定义的和内置的 nn 模块(例如 torch.nn.Linear)。默认情况下使用此函数是不安全的。但对于大多数 torch.compile 用户来说,模型代码不会修改 nn 模块的属性。他们可以通过使用此 API 来受益于 guard 延迟开销的减少。
要使用此 API,请在调用 torch.compile 时使用 guard_filter_fn 参数
>> opt_mod = torch.compile( >> mod, >> options={“guard_filter_fn”: torch.compiler.skip_guard_on_all_nn_modules_unsafe}, >> )