评价此页

torch.utils.checkpoint#

创建于:2025 年 6 月 16 日 | 最后更新于:2025 年 6 月 16 日

注意

检查点是通过在反向传播期间为每个检查点段重新运行前向传递段来实现的。这可能导致诸如 RNG 状态之类的持久状态比没有检查点时更超前。默认情况下,检查点包含用于管理 RNG 状态的逻辑,以便使用 RNG(例如通过 dropout)的检查点传递与非检查点传递相比具有确定性的输出。 the logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. 如果不需要与非检查点传递相比具有确定性的输出,请将 `preserve_rng_state=False` 传递给 `checkpoint` 或 `checkpoint_sequential`,以省略在每个检查点期间存储和恢复 RNG 状态。

存储逻辑会为 CPU 和另一种设备类型(通过 `_infer_device_type` 从除 CPU 张量之外的张量参数推断设备类型)保存和恢复 RNG 状态到 `run_fn`。如果存在多种设备,则设备状态仅为一种设备类型保存,其余设备将被忽略。因此,如果任何被检查点的函数涉及随机性,这可能会导致梯度不正确。(请注意,如果检测到的设备中包含 CUDA 设备,则会优先选择 CUDA 设备;否则,将选择遇到的第一个设备。)如果没有 CPU 张量,将保存和恢复默认设备类型状态(默认值为 `cuda`,可以通过 `DefaultDeviceType` 设置为其他设备)。但是,该逻辑无法预测用户是否会在 `run_fn` 内部将张量移动到新设备。因此,如果在 `run_fn` 中移动张量到新设备(“新”表示不属于 [当前设备 + 张量参数的设备] 集合),则永远无法保证与非检查点传递相比具有确定性的输出。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, early_stop=True, **kwargs)[source]#

检查点模型或模型的一部分。

激活检查点是一种以计算换内存的技术。在反向传播期间计算梯度时,检查点区域的前向计算不保存用于反向传播的张量,而是在反向传播期间重新计算它们。激活检查点可以应用于模型的任何部分。

目前提供两种检查点实现,由 `use_reentrant` 参数确定。建议使用 `use_reentrant=False`。请参阅下面的注释以讨论它们的区别。

警告

如果反向传播期间的 `function` 调用与前向传递不同,例如由于全局变量,则检查点版本可能不相等,可能导致引发错误或导致静默的错误梯度。

警告

`use_reentrant` 参数应显式传递。在 2.9 版本中,如果未传递 `use_reentrant`,我们将引发异常。如果您正在使用 `use_reentrant=True` 变体,请参阅下面的注释以了解重要的注意事项和潜在限制。

注意

检查点的可重入变体 (`use_reentrant=True`) 和非可重入变体 (`use_reentrant=False`) 在以下方面有所不同:

  • 非可重入检查点在重新计算完所有必需的中间激活后立即停止重新计算。此功能默认启用,但可以通过 `set_checkpoint_early_stop()` 禁用。可重入检查点在反向传播期间始终完全重新计算 `function`。

  • 可重入变体在前向传递期间不记录 autograd 图,因为它在 `torch.no_grad()` 下运行前向传递。非可重入版本会记录 autograd 图,允许在检查点区域内的图上执行反向传播。

  • 可重入检查点仅支持 `torch.autograd.backward()` API 进行反向传播,而不带其 `inputs` 参数,而后非可重入版本支持所有方式的反向传播。

  • 可重入变体至少需要一个输入和输出具有 `requires_grad=True`。如果此条件不满足,则模型的检查点部分将没有梯度。非可重入版本没有此要求。

  • 可重入版本不将嵌套结构(例如,自定义对象、列表、字典等)中的张量视为参与 autograd,而非可重入版本则将其视为参与。

  • 可重入检查点不支持包含从计算图中分离的张量的检查点区域,而非常可重入版本支持。对于可重入变体,如果检查点段包含使用 `detach()` 或 `torch.no_grad()` 分离的张量,则反向传播将引发错误。这是因为 `checkpoint` 使所有输出都需要梯度,当模型中定义了一个张量不应有梯度时,这会导致问题。为避免此问题,请在 `checkpoint` 函数外部分离张量。

参数
  • function – 描述模型或模型一部分前向传递中运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户传递 `(activation, hidden)`,`function` 应该正确地将第一个输入用作 `activation`,第二个输入用作 `hidden`。

  • args – 包含 `function` 输入的元组。

关键字参数
  • preserve_rng_state (bool, optional) – 在每个检查点期间省略存储和恢复 RNG 状态。请注意,在 `torch.compile` 下,此标志无效,我们始终保留 RNG 状态。默认值: `True`。

  • use_reentrant (bool) – 指定是否使用需要可重入 autograd 的激活检查点变体。此参数应显式传递。在 2.9 版本中,如果未传递 `use_reentrant`,我们将引发异常。如果 `use_reentrant=False`,`checkpoint` 将使用不需要可重入 autograd 的实现。这使得 `checkpoint` 支持其他功能,例如与 `torch.autograd.grad` 正常工作以及支持传递给检查点函数的关键字参数。

  • context_fn (Callable, optional) – 返回两个上下文管理器元组的可调用对象。函数及其重计算将在第一个和第二个上下文管理器下分别运行。此参数仅在 `use_reentrant=False` 时受支持。

  • determinism_check (str, optional) – 指定要执行的确定性检查的字符串。默认设置为 `"default"`,它将重计算张量的形状、数据类型和设备与保存的张量进行比较。要关闭此检查,请指定 `"none"`。目前这是唯一支持的两个值。如果您希望看到更多确定性检查,请打开一个 issue。此参数仅在 `use_reentrant=False` 时受支持。如果 `use_reentrant=True`,则确定性检查始终禁用。

  • debug (bool, optional) – 如果为 `True`,错误消息还将包括在原始前向计算和重计算期间运行的操作的跟踪。此参数仅在 `use_reentrant=False` 时受支持。

  • early_stop (bool, optional) – 如果为 `True`,非可重入检查点会尽快停止重计算,因为它已经计算了所有需要的张量。如果 `use_reentrant=True`,则忽略此参数。可以使用 `set_checkpoint_early_stop()` 上下文管理器全局覆盖。默认值: `True`。

返回

在 `*args` 上运行 `function` 的输出。

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]#

检查点顺序模型以节省内存。

顺序模型按顺序(顺序)执行一系列模块/函数。因此,我们可以将此类模型划分为多个段并检查每个段。除最后一个段外,所有段都不会存储中间激活。每个检查点段的输入将被保存,以便在反向传播期间重新运行该段。

警告

`use_reentrant` 参数应显式传递。在 2.9 版本中,如果未传递 `use_reentrant`,我们将引发异常。如果您正在使用 `use_reentrant=True` 变体,请参阅 :func:`~torch.utils.checkpoint.checkpoint` 以了解此变体的​​重要注意事项和限制。建议使用 `use_reentrant=False`。

参数
  • functions – 一个 `torch.nn.Sequential` 或模块/函数列表(构成模型),按顺序运行。

  • segments – 在模型中创建的块的数量。

  • input – 输入到 `functions` 的一个张量。

  • preserve_rng_state (bool, optional) – 在每个检查点期间省略存储和恢复 RNG 状态。默认值: `True`。

  • use_reentrant (bool) – 指定是否使用需要可重入 autograd 的激活检查点变体。此参数应显式传递。在 2.5 版本中,如果未传递 `use_reentrant`,我们将引发异常。如果 `use_reentrant=False`,`checkpoint` 将使用不需要可重入 autograd 的实现。这使得 `checkpoint` 支持其他功能,例如与 `torch.autograd.grad` 正常工作以及支持传递给检查点函数的关键字参数。

返回

在 `*inputs` 上顺序运行 `functions` 的输出。

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]#

上下文管理器,用于设置检查点在运行时是否应打印额外的调试信息。有关更多信息,请参阅 `checkpoint()` 的 `debug` 标志。请注意,设置时,此上下文管理器会覆盖传递给检查点的 `debug` 值。要推迟到本地设置,请将 `None` 传递给此上下文。

参数

enabled (bool) – 检查点是否应打印调试信息。默认为“None”。

class torch.utils.checkpoint.CheckpointPolicy(value)[source]#

用于指定反向传播期间检查点策略的枚举。

支持以下策略:

  • `{MUST,PREFER}_SAVE`: 操作的输出将在前向传递期间保存,并且不会在反向传递期间重新计算。

  • `{MUST,PREFER}_RECOMPUTE`: 操作的输出不会在前向传递期间保存,并且将在反向传递期间重新计算。

使用 `MUST_*` 而非 `PREFER_*` 来指示该策略不应被 `torch.compile` 等其他子系统覆盖。

注意

一个始终返回 `PREFER_RECOMPUTE` 的策略函数等同于标准的检查点。

一个始终返回 `PREFER_SAVE` 的策略函数不等于不使用检查点。使用这样的策略将保存额外的张量,而不仅仅是那些实际上用于梯度计算的张量。

class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source]#

在选择性检查点期间传递给策略函数的上下文。

此类用于在选择性检查点期间将相关元数据传递给策略函数。元数据包括策略函数当前调用是否在重计算期间。

示例

>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    print(ctx.is_recompute)
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )
torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source]#

帮助程序,用于在激活检查点期间避免重新计算某些操作。

与 `torch.utils.checkpoint.checkpoint` 一起使用此功能,以控制在反向传播期间重新计算哪些操作。

参数
  • policy_fn_or_list (Callable or List) –

    • 如果提供了策略函数,它应该接受一个 `SelectiveCheckpointContext`、`OpOverload`、op 的 args 和 kwargs,并返回一个 `CheckpointPolicy` 枚举值,指示该 op 的执行是否应被重新计算。

    • 如果提供了一个操作列表,它等同于一个策略函数,该函数为指定的操作返回 `CheckpointPolicy.MUST_SAVE`,为所有其他操作返回 `CheckpointPolicy.PREFER_RECOMPUTE`。

  • allow_cache_entry_mutation (bool, optional) – 默认情况下,如果以确保正确性为目的而突变任何由选择性激活检查点缓存的张量,则会引发错误。如果设置为 `True`,则禁用此检查。

返回

两个上下文管理器的元组。

示例

>>> import functools
>>>
>>> x = torch.rand(10, 10, requires_grad=True)
>>> y = torch.rand(10, 10, requires_grad=True)
>>>
>>> ops_to_save = [
>>>    torch.ops.aten.mm.default,
>>> ]
>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    if op in ops_to_save:
>>>        return CheckpointPolicy.MUST_SAVE
>>>    else:
>>>        return CheckpointPolicy.PREFER_RECOMPUTE
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> # or equivalently
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
>>>
>>> def fn(x, y):
>>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )