评价此页

graph#

class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source]#

上下文管理器,用于将 CUDA 工作捕获到 torch.cuda.CUDAGraph 对象中,以便稍后回放。

有关通用介绍、详细用法和限制,请参阅 CUDA Graphs

参数:
  • cuda_graph (torch.cuda.CUDAGraph) – 用于捕获的图对象。

  • pool (可选) – 一个不透明的令牌(由调用 graph_pool_handle()other_Graph_instance.pool() 返回),提示该图的捕获可能会共享来自指定池的内存。请参阅 图内存管理 (Graph memory management)

  • stream (torch.cuda.Stream, 可选) – 如果提供,它将被设置为上下文中的当前流。如果未提供,graph 会将其内部辅助流设置为上下文中的当前流。

  • capture_error_mode (str, 可选) – 指定图捕获流的 cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在 CUDA 图捕获期间,某些操作(如 cudaMalloc)可能是不安全的。“global”会对其他线程的操作报错,“thread_local”仅对当前线程的操作报错,“relaxed”则不会对操作报错。除非你熟悉 cudaStreamCaptureMode,否则请勿更改此设置。

注意

为了实现有效的内存共享,如果你传递了一个先前捕获所使用的 pool,且该先前的捕获使用了显式的 stream 参数,你应该为当前的捕获传递相同的 stream 参数。

警告

此 API 处于 Beta 版,未来版本中可能会更改。