CudaGraphModule¶
- class tensordict.nn.CudaGraphModule(module: Callable[[Union[List[Tensor], TensorDictBase]], None], warmup: int = 2, in_keys: Optional[List[NestedKey]] = None, out_keys: Optional[List[NestedKey]] = None, device: Optional[device] = None)¶
CUDA 图的 PyTorch 可调用包装器。
CudaGraphModule
是一个包装类,它为 PyTorch 可调用对象提供了用户友好的 CUDA 图接口。警告
CudaGraphModule
是一个原型特性,其 API 限制在未来可能会发生变化。此类提供了一个用户友好的 CUDA 图接口,允许在 GPU 上进行快速、无 CPU 开销的 GPU 操作。它会运行函数输入的必要检查,并提供类似 nn.Module 的 API 来运行
警告
此模块要求被包装的函数满足几个要求。用户有责任确保所有这些要求都已满足。
函数不能有动态控制流。例如,以下代码片段将无法被 CudaGraphModule 包装
>>> def func(x): ... if x.norm() > 1: ... return x + 1 ... else: ... return x - 1
幸运的是,PyTorch 在大多数情况下都提供了解决方案
>>> def func(x): ... return torch.where(x.norm() > 1, x + 1, x - 1)
函数必须执行一段代码,这段代码可以使用相同的缓冲区精确地重新运行。这意味着不支持动态形状(输入中或代码执行期间形状的变化)。换句话说,输入必须具有恒定的形状。
函数的输出必须是已分离的。如果需要调用优化器,请将其放在输入函数中。例如,以下函数是一个有效的操作符
>>> def func(x, y): ... optim.zero_grad() ... loss_val = loss_fn(x, y) ... loss_val.backward() ... optim.step() ... return loss_val.detach()
输入不应可微分。如果你需要使用 nn.Parameters(或一般的可微分张量),只需编写一个将它们用作全局值的函数,而不是将它们作为输入传递
>>> x = nn.Parameter(torch.randn(())) >>> optim = Adam([x], lr=1) >>> def func(): # right ... optim.zero_grad() ... (x+1).backward() ... optim.step() >>> def func(x): # wrong ... optim.zero_grad() ... (x+1).backward() ... optim.step()
作为张量或 tensordict 的参数和关键字参数可能会改变(前提是设备和形状匹配),但非张量参数和关键字参数不应改变。例如,如果函数接收一个字符串输入,并且在任何时候该输入都发生了更改,则模块将使用 cudagraph 捕获期间使用的字符串静默执行代码。唯一支持的关键字参数是 tensordict_out,以防输入是 tensordict。
如果模块是
TensorDictModuleBase
实例,并且输出 ID 与输入 ID 匹配,则在调用CudaGraphModule
时将保留此身份。在所有其他情况下,输出将被克隆,而不管其元素是否匹配输入中的一个或多个。
警告
CudaGraphModule
按设计不是一个Module
,以避免收集输入模块的参数并将其传递给优化器。- 参数:
module (Callable) – 一个接收张量(或 tensordict)作为输入并输出 PyTreeable 张量集合的函数。如果提供了 tensordict,则模块也可以使用关键字参数运行(参见下面的示例)。
warmup (int, optional) – 如果模块已编译,则预热步数(编译的模块应运行几次,然后才能被 cudagraphs 捕获)。对于所有模块,默认为
2
。in_keys (list of NestedKeys) –
输入键,如果模块以 TensorDict 作为输入。如果此值存在,则默认为
module.in_keys
,否则为None
。注意
如果提供了
in_keys
但为空,则假定模块接收 tensordict 作为输入。这足以使CudaGraphModule
知道该函数应被视为 TensorDictModule,但关键字参数将不会被分发。有关示例,请参见下文。out_keys (list of NestedKeys) – 输出键,如果模块以 TensorDict 作为输出。如果此值存在,则默认为
module.out_keys
,否则为None
。device (torch.device, optional) – 流的设备。
示例
>>> # Wrap a simple function >>> def func(x): ... return x + 1 >>> func = CudaGraphModule(func) >>> x = torch.rand((), device='cuda') >>> out = func(x) >>> assert isinstance(out, torch.Tensor) >>> assert out == x+1 >>> # Wrap a tensordict module >>> func = TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]) >>> func = CudaGraphModule(func) >>> # This can be called either with a TensorDict or regular keyword arguments alike >>> y = func(x=x) >>> td = TensorDict(x=x) >>> td = func(td)
注意
关于调试 CudaGraphModule 错误的提示
- 诸如 operation would make the legacy stream depend on a capturing blocking stream 等错误的调试方法是
首先使用非编译版本的代码进行调试(编译后的代码会隐藏负责跨流依赖关系的代码部分)。这可能是因为您正在进行跨设备操作,导致捕获流依赖于其他流。
- 诸如 Cannot call CUDAGeneratorImpl::current_seed during CUDA graph capture 或其他源自
编译器的错误(而不是编译后的代码!)可能指向在重新编译期间捕获了图。使用 TORCH_LOGS=”+recompiles” python myscrip.py 捕获重新编译,并尝试修复它们。通常,请确保使用了足够的预热步数。如果您在此遇到困难,请提交一个 issue。