评价此页

torch.utils.checkpoint#

创建于: 2025年6月16日 | 最后更新于: 2025年10月29日

注意

Checkpointing(检查点)是通过在反向传播期间为每个检查点片段重新运行前向传播来实现的。这可能会导致像RNG状态这样的持久状态比不进行检查点时更向前推进。默认情况下,检查点包含用于调整RNG状态的逻辑,以便使用RNG(例如通过dropout)进行检查点的前向传播与不进行检查点的传播相比具有确定性的输出。在每次检查点期间,存储和恢复RNG状态的逻辑可能会产生一定的性能损失,具体取决于被检查点操作的运行时长。如果不需要与非检查点传播相比的确定性输出,请为 `checkpoint` 或 `checkpoint_sequential` 提供 `preserve_rng_state=False`,以在每次检查点期间省略RNG状态的存储和恢复。

存储逻辑会为CPU和另一种设备类型(通过 `_infer_device_type` 从排除CPU张量的张量参数推断设备类型)保存和恢复RNG状态到 `run_fn`。如果存在多种设备,设备状态将仅为一种设备类型保存,而其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能会导致梯度不正确。(请注意,如果检测到的设备中包含CUDA设备,则会优先选择CUDA;否则,将选择遇到的第一个设备。)如果没有CPU张量,则会保存和恢复默认设备类型状态(默认值为 `cuda`,可以通过 `DefaultDeviceType` 设置为其他设备)。但是,该逻辑无法预测用户是否在 `run_fn` 内部将张量移动到新设备。因此,如果在 `run_fn` 中将张量移动到新设备(“新”表示不属于[当前设备+张量参数的设备]集合的设备),则永远不能保证与非检查点传播相比的确定性输出。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, early_stop=True, **kwargs)[source]#

检查点模型或模型的一部分。

激活检查点是一种用计算换内存的技术。默认情况下,在前向传播期间计算的张量会一直保留,直到在反向传播的梯度计算中使用它们。为了减少内存使用,传递给 `function` 的张量不会保留到反向传播。取而代之的是,传递给 `args` 的张量会被保留,而在反向传播期间,通过重新调用 `function` 来根据需要重新计算未保存的张量以进行梯度计算。激活检查点可以应用于模型的任何部分——这有时被称为“检查点”该模型的部分。

目前有两种可用的检查点实现,由 `use_reentrant` 参数决定。建议使用 `use_reentrant=False`。请参考下面的说明讨论它们的区别。

警告

如果在反向传播期间调用 `function` 与前向传播不同(例如,由于全局变量),则检查点版本可能不相等,可能导致错误被引发或导致静默错误的梯度。

警告

应显式传递 `use_reentrant` 参数。在2.9版本中,如果未传递 `use_reentrant`,我们将引发异常。如果您正在使用 `use_reentrant=True` 变体,请参阅下面的说明以了解重要的考虑因素和潜在限制。

注意

检查点的重入变体(`use_reentrant=True`)和非重入变体(`use_reentrant=False`)在以下方面有所不同

  • 非重入检查点在重新计算完所有必需的中间激活后立即停止重计算。此功能默认启用,但可以通过 `set_checkpoint_early_stop()` 禁用。重入检查点在反向传播期间始终完全重新计算 `function`。

  • 重入变体在前向传播期间不记录autograd图,因为它在前向传播时以 `torch.no_grad()` 运行。非重入版本会记录autograd图,允许在检查点区域内对图执行反向传播。

  • 重入检查点仅支持不带 `inputs` 参数的反向传播API `torch.autograd.backward()`,而非重入版本支持所有反向传播方式。

  • 重入变体至少需要一个输入和输出具有 `requires_grad=True`。如果未满足此条件,模型的检查点部分将没有梯度。非重入版本没有此要求。

  • 重入版本不将嵌套结构(例如,自定义对象、列表、字典等)中的张量视为参与autograd,而非重入版本则会。

  • 重入检查点不支持与计算图分离的张量(detached tensors)的检查点区域,而非重入版本支持。对于重入变体,如果检查点片段包含通过 `detach()` 或 `torch.no_grad()` 分离的张量,则反向传播将引发错误。这是因为 `checkpoint` 使所有输出都要求梯度,而当模型中某个张量被定义为没有梯度时,这会导致问题。为避免这种情况,请在 `checkpoint` 函数外部分离张量。

参数:
  • function – 描述模型或模型一部分在前向传播中需要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户传递 `(activation, hidden)`,`function` 应该正确地将第一个输入用作 `activation`,第二个输入用作 `hidden`。

  • args – 包含传递给 `function` 的输入的元组。

关键字参数:
  • preserve_rng_state (bool, optional) – 在每次检查点期间省略RNG状态的存储和恢复。注意,在使用 `torch.compile` 时,此标志无效,我们始终保留RNG状态。默认为 `True`。

  • use_reentrant (bool) – 指定是否使用需要重入autograd的激活检查点变体。此参数应显式传递。在2.9版本中,如果未传递 `use_reentrant`,我们将引发异常。如果 `use_reentrant=False`,`checkpoint` 将使用一个不需要重入autograd的实现。这使得 `checkpoint` 可以支持其他功能,例如与 `torch.autograd.grad` 正常工作,并支持传递给检查点函数的关键字参数。

  • context_fn (Callable, optional) – 返回两个上下文管理器元组的可调用对象。函数及其重计算将在第一个和第二个上下文管理器下分别运行。此参数仅在使用 `use_reentrant=False` 时支持。

  • determinism_check (str, optional) – 一个字符串,指定要执行的确定性检查。默认设置为 `"default"`,它比较重计算张量的形状、数据类型和设备与已保存张量的形状、数据类型和设备。要关闭此检查,请指定 `"none"`。目前只有这两个支持的值。如果您希望看到更多的确定性检查,请提交一个issue。此参数仅在使用 `use_reentrant=False` 时支持;如果 `use_reentrant=True`,则始终禁用确定性检查。

  • debug (bool, optional) – 如果为 `True`,错误消息还将包括原始前向计算和重计算期间运行的操作的跟踪。此参数仅在使用 `use_reentrant=False` 时支持。

  • early_stop (bool, optional) – 如果为 `True`,非重入检查点将在计算完所有需要的张量后停止重计算。如果 `use_reentrant=True`,此参数将被忽略。可以使用 `set_checkpoint_early_stop()` 上下文管理器全局覆盖。默认为 `True`。

返回:

在 `args` 上运行 `function` 的输出。

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]#

检查点顺序模型以节省内存。

顺序模型按顺序(顺序地)执行模块/函数的列表。因此,我们可以将这样的模型分成多个片段并检查点每个片段。除最后一个片段外,所有片段都不会存储中间激活。每个检查点片段的输入将被保存,以便在反向传播中重新运行该片段。

警告

应显式传递 `use_reentrant` 参数。在2.9版本中,如果未传递 `use_reentrant`,我们将引发异常。如果您正在使用 `use_reentrant=True` 变体,请参阅 :func:`~torch.utils.checkpoint.checkpoint` 以了解此变体的注意事项和限制。建议使用 `use_reentrant=False`。

参数:
  • functions – 一个 `torch.nn.Sequential` 或由模块或函数组成的列表(构成模型),按顺序执行。

  • segments – 要在模型中创建的块的数量。

  • input – 输入到 `functions` 的一个张量。

  • preserve_rng_state (bool, optional) – 在每次检查点期间省略RNG状态的存储和恢复。默认为 `True`。

  • use_reentrant (bool) – 指定是否使用需要重入autograd的激活检查点变体。此参数应显式传递。在2.5版本中,如果未传递 `use_reentrant`,我们将引发异常。如果 `use_reentrant=False`,`checkpoint` 将使用一个不需要重入autograd的实现。这使得 `checkpoint` 可以支持其他功能,例如与 `torch.autograd.grad` 正常工作,并支持传递给检查点函数的关键字参数。

返回:

在 `inputs` 上顺序运行 `functions` 的输出。

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]#

上下文管理器,用于设置检查点在运行时是否应打印额外的调试信息。有关更多信息,请参阅 `checkpoint()` 的 `debug` 标志。请注意,当设置此选项时,此上下文管理器将覆盖传递给 `checkpoint` 的 `debug` 值。要遵循本地设置,请为此上下文传递 `None`。

参数:

enabled (bool) – 检查点是否应打印调试信息。默认为“None”。

class torch.utils.checkpoint.CheckpointPolicy(value)[source]#

枚举,用于指定反向传播期间的检查点策略。

支持以下策略:

  • {MUST,PREFER}_SAVE:操作的输出将在前向传播期间保存,并且不会在反向传播期间重新计算。

  • {MUST,PREFER}_RECOMPUTE:操作的输出不会在前向传播期间保存,并且将在反向传播期间重新计算。

使用 `MUST_*` 而不是 `PREFER_*` 来指示策略不应被其他子系统(如 `torch.compile`)覆盖。

注意

一个始终返回 `PREFER_RECOMPUTE` 的策略函数等同于普通的检查点。

一个为每个操作返回 `PREFER_SAVE` 的策略函数不等于不使用检查点。使用这样的策略会保存额外的张量,而不仅仅是那些实际用于梯度计算的张量。

class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source]#

在选择性检查点期间传递给策略函数的上下文。

此类用于在选择性检查点期间将相关元数据传递给策略函数。元数据包括当前调用策略函数是在重新计算期间还是不在。

示例

>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    print(ctx.is_recompute)
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )
torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source]#

帮助避免在激活检查点期间重新计算某些操作。

与 `torch.utils.checkpoint.checkpoint` 一起使用此函数来控制在反向传播期间重新计算哪些操作。

参数:
  • policy_fn_or_list (Callable or List) –

    • 如果提供了策略函数,它应该接受一个 `SelectiveCheckpointContext`、`OpOverload`、操作的 args 和 kwargs,并返回一个 `CheckpointPolicy` 枚举值,指示操作的执行是否应该被重新计算。

    • 如果提供了一个操作列表,它等同于一个策略函数为指定的操作返回 `CheckpointPolicy.MUST_SAVE`,而为所有其他操作返回 `CheckpointPolicy.PREFER_RECOMPUTE`。

  • allow_cache_entry_mutation (bool, optional) – 默认情况下,如果由选择性激活检查点缓存的任何张量被修改,将引发错误以确保正确性。如果设置为 `True`,则禁用此检查。

返回:

两个上下文管理器的元组。

示例

>>> import functools
>>>
>>> x = torch.rand(10, 10, requires_grad=True)
>>> y = torch.rand(10, 10, requires_grad=True)
>>>
>>> ops_to_save = [
>>>    torch.ops.aten.mm.default,
>>> ]
>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    if op in ops_to_save:
>>>        return CheckpointPolicy.MUST_SAVE
>>>    else:
>>>        return CheckpointPolicy.PREFER_RECOMPUTE
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> # or equivalently
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
>>>
>>> def fn(x, y):
>>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )