评价此页

torch.cuda.make_graphed_callables#

torch.cuda.make_graphed_callables(callables: Union[Module, Callable[[...], object]], sample_args: tuple[torch.Tensor, ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None) Union[Module, Callable[[...], object]][source]#
torch.cuda.make_graphed_callables(callables: tuple[Union[torch.nn.modules.module.Module, Callable[..., object]], ...], sample_args: tuple[tuple[torch.Tensor, ...], ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None) tuple[Union[torch.nn.modules.module.Module, Callable[..., object]], ...]

接受可调用对象(函数或nn.Module)并返回图化版本。

每个图化可调用对象的正向传播将源可调用对象的正向 CUDA 工作作为单个 autograd 节点中的 CUDA 图运行。

图化可调用对象的前向传播还将一个反向节点附加到 autograd 图中。在反向传播期间,此节点将可调用对象的反向工作作为 CUDA 图运行。

因此,每个图化可调用对象都应成为 autograd 启用的训练循环中其源可调用对象的即插即用替代品。

有关详细用法和限制,请参阅部分网络捕获

如果传递多个可调用对象的元组,则它们的捕获将使用相同的内存池。请参阅图内存管理了解何时适用。

参数
  • callablestorch.nn.ModulePython 函数,或 元组 中的 这些)– 要图化的可调用对象或可调用对象。有关传递可调用对象元组何时适用的信息,请参阅图内存管理。如果传递可调用对象元组,则元组中的顺序必须与实时工作负载中的运行顺序相同。

  • sample_args元组 中的 Tensor,或 元组 中的 元组 中的 Tensor)– 为每个可调用对象采样参数。如果传递单个可调用对象,则 sample_args 必须是参数 Tensor 的单个元组。如果传递了可调用对象元组,则 sample_args 必须是参数 Tensor 的元组元组。

  • num_warmup_itersint)– 预热迭代次数。目前,DataDistributedParallel 需要 11 次预热迭代。默认值:3

  • allow_unused_inputbool)– 如果为 False,则指定未在计算输出时使用的输入(因此其 grad 始终为零)将引发错误。默认为 False。

  • pool可选)– Token(由 graph_pool_handle()other_Graph_instance.pool() 返回)提示此图可能与指示的池共享内存。请参阅图内存管理

注意

sample_args 中每个 Tensor 的 requires_grad 状态必须与训练循环中相应真实输入的预期状态匹配。

警告

此 API 处于 Beta 版,未来版本中可能会更改。

警告

每个可调用对象的 sample_args 必须只包含 Tensor。不允许其他类型。

警告

返回的可调用对象不支持高阶微分(例如,二次反向传播)。

警告

在传递给 make_graphed_callables() 的任何 Module 中,只有参数可以是可训练的。缓冲区必须具有 requires_grad=False

警告

在通过 make_graphed_callables() 传递 torch.nn.Module 后,您不能添加或删除该 Module 的任何参数或缓冲区。

警告

传递给 make_graphed_callables()torch.nn.Module 在传递时不得在其上注册模块挂钩。但是,在传递给 make_graphed_callables() 之后在模块上注册挂钩是允许的。

警告

运行图化可调用对象时,必须以其 sample_args 中出现的相同顺序和格式传递其参数。

警告

make_graphed_callables() 中的自动混合精度仅在禁用缓存时受支持。上下文管理器 torch.cuda.amp.autocast() 必须将 cache_enabled=False