评价此页

torch.utils.checkpoint#

创建日期:2025 年 6 月 16 日 | 最近更新日期:2025 年 10 月 29 日

注意

检查点(Checkpointing)是通过在反向传播期间为每个检查点分段重新运行前向传递分段来实现的。这可能会导致像 RNG 状态(随机数生成器状态)这样的持久状态比没有使用检查点时更提前。默认情况下,检查点包含调度 RNG 状态的逻辑,使得使用 RNG 的检查点传递(例如通过 dropout)与非检查点传递相比具有确定性的输出。根据检查点操作的运行时间,存储和恢复 RNG 状态的逻辑可能会导致中度的性能损失。如果不需要与非检查点传递相比的确定性输出,请将 preserve_rng_state=False 传递给 checkpointcheckpoint_sequential,以忽略在每个检查点期间存储和恢复 RNG 状态的操作。

存储逻辑为 CPU 和另一种设备类型(通过 _infer_device_type 从 Tensor 参数中推断设备类型,排除 CPU 张量)保存并恢复 run_fn 的 RNG 状态。如果存在多个设备,设备状态将仅为单一设备类型的设备保存,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,可能会导致梯度不正确。(请注意,如果检测到的设备中包含 CUDA 设备,则会优先选择它;否则,将选择遇到的第一个设备。)如果没有 CPU 张量,将保存并恢复默认设备类型状态(默认值为 cuda,可以通过 DefaultDeviceType 设置为其他设备)。然而,该逻辑无法预见用户是否会在 run_fn 本身内部将张量移动到新设备。因此,如果你在 run_fn 内部将张量移动到新设备(“新”指不属于 [当前设备 + 张量参数的设备] 集合),则永远无法保证与非检查点传递相比的确定性输出。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, early_stop=True, **kwargs)[源代码]#

对模型或模型的一部分设置检查点。

激活检查点(Activation checkpointing)是一种以计算换取内存的技术。默认情况下,在前向传递期间计算的张量会一直保持活跃,直到它们在反向传递的梯度计算中使用。为了减少这种内存占用,在传递的 function 中产生的张量不会保持活跃到反向传递。相反,在 args 中传递的任何张量都会保持活跃,未保存的张量将通过在反向传递中根据梯度计算的需要重新调用 function 来重新计算。激活检查点可以应用于模型的任何部分——这有时被描述为对模型的该部分“设置检查点”。

目前有两种可用的检查点实现,由 use_reentrant 参数决定。建议你使用 use_reentrant=False。请参阅下面的注释以了解它们的差异。

警告

如果反向传递期间的 function 调用与前向传递不同(例如由于全局变量),检查点版本可能不等价,这可能会导致引发错误或导致静默的错误梯度。

警告

use_reentrant 参数应显式传递。在 2.9 版本中,如果未传递 use_reentrant,我们将引发异常。如果你使用的是 use_reentrant=True 变体,请参阅下面的注释以了解重要的注意事项和潜在限制。

注意

检查点的重入变体(use_reentrant=True)和非重入变体(use_reentrant=False)在以下方面有所不同:

  • 非重入检查点一旦所有需要的中间激活都已重新计算,就会停止重新计算。此功能默认启用,但可以使用 set_checkpoint_early_stop() 禁用。重入检查点在反向传递期间始终完整地重新计算 function

  • 重入变体在前向传递期间不记录 autograd 图,因为它是在 torch.no_grad() 下运行前向传递的。非重入版本会记录 autograd 图,允许在检查点区域内对图执行反向传播。

  • 重入检查点仅支持不带 inputs 参数的反向传递 torch.autograd.backward() API,而非重入版本支持执行反向传播的所有方式。

  • 对于重入变体,至少一个输入和输出必须具有 requires_grad=True。如果未满足此条件,模型的检查点部分将没有梯度。非重入版本没有此要求。

  • 重入版本不认为嵌套结构(例如自定义对象、列表、字典等)中的张量参与了 autograd,而非重入版本则认为参与了。

  • 重入检查点不支持包含从计算图中分离出的张量的检查点区域,而非重入版本支持。对于重入变体,如果检查点分段包含使用 detach()torch.no_grad() 分离的张量,反向传递将引发错误。这是因为 checkpoint 会使所有输出都需要梯度,当张量在模型中被定义为没有梯度时,这会导致问题。为了避免这种情况,请在 checkpoint 函数之外分离张量。

参数:
  • function – 描述在前向传递中要运行的模型或模型的一部分。它还应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户传递 (activation, hidden),则 function 应该正确地将第一个输入用作 activation,将第二个输入用作 hidden

  • args – 包含 function 输入的元组

关键字参数:
  • preserve_rng_state (bool, 可选) – 忽略在每个检查点期间存储和恢复 RNG 状态的操作。请注意,在 torch.compile 下,此标志不起作用,我们始终保留 RNG 状态。默认值:True

  • use_reentrant (bool) – 指定是否使用需要重入 autograd 的激活检查点变体。此参数应显式传递。在 2.9 版本中,如果未传递 use_reentrant,我们将引发异常。如果 use_reentrant=Falsecheckpoint 将使用不需要重入 autograd 的实现。这允许 checkpoint 支持额外的功能,例如按预期与 torch.autograd.grad 配合工作,并支持输入到检查点函数的关键字参数。

  • context_fn (Callable, 可选) – 一个返回两个上下文管理器元组的可调用对象。该函数及其重新计算将分别在第一个和第二个上下文管理器下运行。仅当 use_reentrant=False 时才支持此参数。

  • determinism_check (str, 可选) – 指定要执行的确定性检查的字符串。默认设置为 "default",它将重新计算的张量的形状、数据类型和设备与保存的张量进行比较。要关闭此检查,请指定 "none"。目前这些是仅有的两个受支持的值。如果你希望看到更多的确定性检查,请提交 issue。此参数仅在 use_reentrant=False 时受支持;如果 use_reentrant=True,确定性检查始终被禁用。

  • debug (bool, 可选) – 如果为 True,错误消息还将包括原始前向计算以及重新计算期间运行的算子轨迹。仅当 use_reentrant=False 时才支持此参数。

  • early_stop (bool, 可选) – 如果为 True,非重入检查点一旦计算完所有需要的张量就会停止重新计算。如果 use_reentrant=True,则忽略此参数。可以使用 set_checkpoint_early_stop() 上下文管理器全局覆盖。默认值:True

返回:

*args 上运行 function 的输出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[源代码]#

对顺序模型设置检查点以节省内存。

顺序模型按顺序执行一系列模块/函数。因此,我们可以将此类模型划分为各个分段并对每个分段设置检查点。除最后一个分段外的所有分段都不会存储中间激活。每个检查点分段的输入将被保存,以便在反向传递中重新运行该分段。

警告

use_reentrant 参数应显式传递。在 2.9 版本中,如果未传递 use_reentrant,我们将引发异常。如果你正在使用 use_reentrant=True 变体,请参阅 torch.utils.checkpoint.checkpoint() 了解此变体的重要注意事项和限制。建议你使用 use_reentrant=False

参数:
  • functions – 一个 torch.nn.Sequential 或顺序运行的模块或函数列表(构成模型)。

  • segments – 在模型中创建的块数

  • input – 作为 functions 输入的张量

  • preserve_rng_state (bool, 可选) – 忽略在每个检查点期间存储和恢复 RNG 状态的操作。默认值:True

  • use_reentrant (bool) – 指定是否使用需要重入 autograd 的激活检查点变体。此参数应显式传递。在 2.5 版本中,如果未传递 use_reentrant,我们将引发异常。如果 use_reentrant=Falsecheckpoint 将使用不需要重入 autograd 的实现。这允许 checkpoint 支持额外的功能,例如按预期与 torch.autograd.grad 配合工作,并支持输入到检查点函数的关键字参数。

返回:

*inputs 上顺序运行 functions 的输出

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[源代码]#

设置检查点在运行时是否应打印额外调试信息的上下文管理器。有关更多信息,请参阅 checkpoint()debug 标志。请注意,设置此上下文管理器后,它会覆盖传递给检查点的 debug 值。要遵循本地设置,请向此上下文传递 None

参数:

enabled (bool) – 检查点是否应打印调试信息。默认值为 ‘None’。

class torch.utils.checkpoint.CheckpointPolicy(value)[源代码]#

用于指定反向传播期间检查点策略的枚举。

支持以下策略

  • {MUST,PREFER}_SAVE:操作的输出将在前向传递期间保存,并且在反向传递期间不会重新计算

  • {MUST,PREFER}_RECOMPUTE:操作的输出不会在前向传递期间保存,并将在反向传递期间重新计算

  • {MUST,PREFER}_CPU_OFFLOAD:操作的输出将在前向传递期间保存,卸载到 CPU,并在反向传递期间重新加载到 GPU

使用 MUST_* 而不是 PREFER_* 表示该策略不应被 torch.compile 等其他子系统覆盖。

注意

始终返回 PREFER_RECOMPUTE 的策略函数等同于传统的检查点。

对每个算子都返回 PREFER_SAVE 的策略函数并不等同于不使用检查点。使用此类策略将保存额外的张量,而不限于梯度计算实际需要的张量。

class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[源代码]#

在选择性检查点期间传递给策略函数的上下文。

此类用于在选择性检查点期间向策略函数传递相关的元数据。元数据包括当前策略函数的调用是否处于重新计算期间。

示例

>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    print(ctx.is_recompute)
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )
torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[源代码]#

辅助函数,用于在激活检查点期间避免重新计算某些算子。

将其与 torch.utils.checkpoint.checkpoint 配合使用,以控制在反向传递期间重新计算哪些操作。

参数:
  • policy_fn_or_list (Callable 或 List) –

    • 如果提供了策略函数,它应该接受一个 SelectiveCheckpointContextOpOverload、算子的参数(args 和 kwargs),并返回一个 CheckpointPolicy 枚举值,指示是否应重新计算该算子的执行。

    • 如果提供了一组操作列表,它等同于为指定操作返回 CheckpointPolicy.MUST_SAVE,并为所有其他操作返回 CheckpointPolicy.PREFER_RECOMPUTE 的策略。

  • allow_cache_entry_mutation (bool, 可选) – 默认情况下,如果选择性激活检查点缓存的任何张量被修改,将引发错误以确保正确性。如果设置为 True,则禁用此检查。

返回:

包含两个上下文管理器的元组。

示例

>>> import functools
>>>
>>> x = torch.rand(10, 10, requires_grad=True)
>>> y = torch.rand(10, 10, requires_grad=True)
>>>
>>> ops_to_save = [
>>>    torch.ops.aten.mm.default,
>>> ]
>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    if op in ops_to_save:
>>>        return CheckpointPolicy.MUST_SAVE
>>>    else:
>>>        return CheckpointPolicy.PREFER_RECOMPUTE
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> # or equivalently
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
>>>
>>> def fn(x, y):
>>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )