评价此页

graph#

class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source]#

上下文管理器,它将 CUDA 工作捕获到 torch.cuda.CUDAGraph 对象中以供以后重放。

有关通用介绍、详细用法和限制,请参阅 CUDA Graphs

参数
  • cuda_graph (torch.cuda.CUDAGraph) – 用于捕获的 Graph 对象。

  • pool (optional) – 不透明令牌(通过调用 graph_pool_handle()other_Graph_instance.pool() 返回),指示此 graph 的捕获可能共享指定池的内存。请参阅 Graph 内存管理

  • stream (torch.cuda.Stream, optional) – 如果提供,将在上下文中设置为当前流。如果未提供,则 graph 会将其自身的内部辅助流设置为上下文中的当前流。

  • capture_error_mode (str, optional) – 指定 graph 捕获流的 cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在 cuda graph 捕获期间,某些操作(如 cudaMalloc)可能不安全。“global”会因其他线程中的操作而报错,“thread_local”只会因当前线程中的操作而报错,“relaxed”则不会因操作而报错。除非您熟悉 cudaStreamCaptureMode,否则请勿更改此设置。

注意

为了有效地共享内存,如果您传递一个由先前捕获使用的 pool,并且先前的捕获使用了显式的 stream 参数,那么您应该将相同的 stream 参数传递给此次捕获。

警告

此 API 处于 Beta 版,未来版本中可能会更改。