graph#
- class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source]#
用于将 CUDA 工作捕获到
torch.cuda.CUDAGraph
对象中以供以后重放的上下文管理器。有关通用介绍、详细使用和约束,请参阅 CUDA Graphs。
- 参数
cuda_graph (torch.cuda.CUDAGraph) – 用于捕获的 Graph 对象。
pool (optional) – 不透明标记(从调用
graph_pool_handle()
或other_Graph_instance.pool()
返回)提示此 graph 的捕获可能共享来自指定池的内存。请参阅 Graph 内存管理。stream (torch.cuda.Stream, optional) – 如果提供,将在上下文中设置为当前流。如果未提供,
graph
会将其内部的 side stream 设置为上下文中的当前流。capture_error_mode (str, optional) – 指定捕获流的 cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在 cuda graph 捕获期间,某些操作(如 cudaMalloc)可能不安全。“global”会在其他线程的操作上引发错误,“thread_local”仅会针对当前线程的操作引发错误,而“relaxed”则不会引发错误。除非您熟悉 cudaStreamCaptureMode,否则请勿更改此设置。
注意
为了实现有效的内存共享,如果您传递了一个前次捕获使用的
pool
,并且前次捕获使用了显式的stream
参数,您应该将相同的stream
参数传递给本次捕获。警告
此 API 处于 Beta 版,未来版本中可能会更改。