快捷方式

CudaGraphModule

class tensordict.nn.CudaGraphModule(module: Callable[[Union[List[Tensor], TensorDictBase]], None], warmup: int = 2, in_keys: Optional[List[NestedKey]] = None, out_keys: Optional[List[NestedKey]] = None, device: Optional[device] = None)

PyTorch 可调用对象的 cudagraph 包装器。

CudaGraphModule 是一个包装器类,它为 PyTorch 可调用对象提供用户友好的 CUDA 图接口。

警告

CudaGraphModule 是一个原型功能,其 API 限制在未来可能会发生变化。

此类为 CUDA 图提供了用户友好的接口,允许 GPU 上操作的快速、无 CPU 开销执行。它运行对函数输入的必要检查,并提供类似 nn.Module 的 API 来运行

警告

此模块要求包装的函数满足一些要求。用户有责任确保所有这些要求都已满足。

  • 函数不能有动态控制流。例如,以下代码片段将在 CudaGraphModule 中包装失败

    >>> def func(x):
    ...     if x.norm() > 1:
    ...         return x + 1
    ...     else:
    ...         return x - 1
    

    幸运的是,PyTorch 在大多数情况下都提供了解决方案

    >>> def func(x):
    ...     return torch.where(x.norm() > 1, x + 1, x - 1)
    
  • 函数必须执行一个可以使用相同缓冲区精确重放的代码。这意味着不支持动态形状(输入中或代码执行期间形状的变化)。换句话说,输入必须具有恒定的形状。

  • 函数的输出必须是分离的。如果需要调用优化器,请将其放在输入函数中。例如,以下函数是一个有效的运算符

    >>> def func(x, y):
    ...     optim.zero_grad()
    ...     loss_val = loss_fn(x, y)
    ...     loss_val.backward()
    ...     optim.step()
    ...     return loss_val.detach()
    
  • 输入不应可微分。如果你需要使用 nn.Parameters(或一般可微分张量),只需编写一个将它们用作全局值而不是将它们作为输入传递的函数

    >>> x = nn.Parameter(torch.randn(()))
    >>> optim = Adam([x], lr=1)
    >>> def func(): # right
    ...     optim.zero_grad()
    ...     (x+1).backward()
    ...     optim.step()
    >>> def func(x): # wrong
    ...     optim.zero_grad()
    ...     (x+1).backward()
    ...     optim.step()
    
  • 作为张量或 tensordict 的 args 和 kwargs 可能会改变(前提是设备和形状匹配),但非张量 args 和 kwargs 不应改变。例如,如果函数接收一个字符串输入并且该输入在任何时候都被更改,模块将静默地使用捕获 cudagraph 时使用的字符串执行代码。唯一支持的关键字参数是 tensordict_out,以防输入是 tensordict。

  • 如果模块是 TensorDictModuleBase 实例,并且输出 ID 与输入 ID 匹配,那么在调用 CudaGraphModule 时将保留此身份。在所有其他情况下,输出将被克隆,无论其元素是否匹配输入中的一个或多个。

警告

CudaGraphModule 不是 Module,其设计目的是为了阻止收集输入模块的参数并将其传递给优化器。

参数:
  • module (Callable) – 接收张量(或 tensordict)作为输入并输出 PyTreeable 张量集合的函数。如果提供 tensordict,则模块也可以使用关键字参数运行(请参见下面的示例)。

  • warmup (int, optional) – 如果模块已编译(编译后的模块应在被 cudagraphs 捕获之前运行几次),则进行预热的次数。默认为所有模块的 2

  • in_keys (list of NestedKeys) –

    输入键,如果模块以 TensorDict 作为输入。如果此值存在,则默认为 module.in_keys,否则为 None

    注意

    如果提供了 in_keys 但为空,则假定模块接收 tensordict 作为输入。这足以让 CudaGraphModule 意识到该函数应被视为 TensorDictModule,但关键字参数不会被分派。请参见下面的示例。

  • out_keys (list of NestedKeys) – 输出键,如果模块以 TensorDict 作为输出。如果此值存在,则默认为 module.out_keys,否则为 None

  • device (torch.device, optional) – 要使用的流的设备。

示例

>>> # Wrap a simple function
>>> def func(x):
...     return x + 1
>>> func = CudaGraphModule(func)
>>> x = torch.rand((), device='cuda')
>>> out = func(x)
>>> assert isinstance(out, torch.Tensor)
>>> assert out == x+1
>>> # Wrap a tensordict module
>>> func = TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"])
>>> func = CudaGraphModule(func)
>>> # This can be called either with a TensorDict or regular keyword arguments alike
>>> y = func(x=x)
>>> td = TensorDict(x=x)
>>> td = func(td)

注意

关于调试 CudaGraphModule 错误的提示

  • 诸如operation would make the legacy stream depend on a capturing blocking stream(操作将使旧流依赖于捕获阻塞流)之类的错误

    应首先使用非编译版本进行调试(编译代码将隐藏负责跨流依赖的代码部分)。这可能是因为您正在进行跨设备操作,导致捕获流依赖于其他流。

  • 诸如Cannot call CUDAGeneratorImpl::current_seed during CUDA graph capture(在 CUDA 图捕获期间无法调用 CUDAGeneratorImpl::current_seed)之类的错误,或其他起源于

    编译器(而非编译的代码!)的错误,可能指向在重新编译时发生的图捕获。使用 TORCH_LOGS=”+recompiles” python myscrip.py 捕获重新编译,并尝试修复它们。通常,请确保您使用了足够多的预热步骤。如果您在此类问题上遇到困难,请提交一个 issue。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源