评价此页

torch.cuda.make_graphed_callables#

torch.cuda.make_graphed_callables(callables: Module | Callable[[...], object], sample_args: tuple[Tensor, ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: _POOL_HANDLE | None = None) Module | Callable[[...], object][源]#
torch.cuda.make_graphed_callables(callables: tuple[Module | Callable[[...], object], ...], sample_args: tuple[tuple[Tensor, ...], ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: _POOL_HANDLE | None = None) tuple[Module | Callable[[...], object], ...]

接受可调用对象(函数或 nn.Module)并返回它们的图化版本。

每个图化可调用对象的前向传播会将其源可调用对象的前向 CUDA 工作作为 CUDA 图在一个单独的自动求导节点内运行。

图化可调用对象的前向传播还会向自动求导图添加一个反向传播节点。在反向传播期间,此节点会以 CUDA 图的形式运行可调用对象的反向传播工作。

因此,在启用自动求导的训练循环中,每个图化可调用对象都应是其源可调用对象的即插即用替代。

有关详细用法和限制,请参阅 部分网络捕获

如果您传入一个包含多个可调用对象的元组,它们的捕获将使用相同的内存池。有关何时适合这样做,请参阅 图内存管理

参数:
  • callables (torch.nn.ModulePython 函数,或它们的 tuple) – 要图化的可调用对象或可调用对象元组。有关何时适合传入可调用对象元组,请参阅 图内存管理。如果您传入一个可调用对象元组,它们在元组中的顺序必须与它们在实际工作负载中运行的顺序相同。

  • sample_args (Tensortuple,或 Tensortupletuple) – 每个可调用对象的示例参数。如果传入单个可调用对象,sample_args 必须是一个由参数 Tensor 组成的元组。如果传入一个可调用对象元组,sample_args 必须是一个由参数 Tensor 元组组成的元组。

  • num_warmup_iters (int) – 热身迭代的次数。目前,DataDistributedParallel 需要 11 次迭代进行热身。默认值:3

  • allow_unused_input (bool) – 如果为 False,指定在计算输出时未使用的输入(因此它们的梯度始终为零)将引发错误。默认值为 False。

  • pool可选)– Token(由 graph_pool_handle()other_Graph_instance.pool() 返回)指示此图可能与指定的池共享内存。请参阅 图内存管理

注意

sample_args 中,每个 Tensor 的 requires_grad 状态必须与训练循环中相应实际输入的预期状态相匹配。

警告

此 API 处于 Beta 版,未来版本中可能会更改。

警告

每个可调用对象的 sample_args 必须只包含 Tensor。不允许其他类型。

警告

返回的可调用对象不支持高阶微分(例如,二次反向传播)。

警告

在传递给 make_graphed_callables() 的任何 Module 中,只有参数可以是可训练的。缓冲区必须具有 requires_grad=False

警告

torch.nn.Module 传递给 make_graphed_callables() 后,您不得添加或删除该 Module 的任何参数或缓冲区。

警告

传递给 make_graphed_callables()torch.nn.Module 在传递时不得注册模块钩子。但是,在将模块传递给 make_graphed_callables() *之后* 注册钩子是允许的。

警告

运行图化可调用对象时,您必须按照其 sample_args 中出现的相同顺序和格式传递其参数。

警告

自动混合精度在 make_graphed_callables() 中仅在禁用缓存时受支持。上下文管理器 torch.cuda.amp.autocast() 必须将 cache_enabled 设置为 False