torch.utils.checkpoint#
创建于:2025 年 6 月 16 日 | 最后更新于:2025 年 6 月 16 日
注意
检查点通过在反向传播期间为每个检查点段重新运行前向传播段来实现。这可能导致持久状态(如 RNG 状态)比不使用检查点时更超前。默认情况下,检查点包括调整 RNG 状态的逻辑,以便使用 RNG(例如通过 dropout)的检查点传递的输出与非检查点传递的输出相比是确定性的。保存和恢复 RNG 状态的逻辑可能会根据检查点操作的运行时性能产生中等程度的影响。如果不需要与非检查点传递相比的确定性输出,则将 preserve_rng_state=False
提供给 checkpoint
或 checkpoint_sequential
以省略在每个检查点期间保存和恢复 RNG 状态。
存储逻辑会保存并恢复 CPU 和另一种设备类型(通过 _infer_device_type
从排除 CPU 张量的 Tensor 参数推断设备类型)的 RNG 状态到 run_fn
。如果存在多个设备,则仅会保存单一设备类型的设备状态,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能会导致不正确的梯度。(请注意,如果 CUDA 设备在检测到的设备中,它将被优先考虑;否则,将选择遇到的第一个设备。)如果没有 CPU 张量,则会保存和恢复默认设备类型状态(默认值为 cuda
,可以通过 DefaultDeviceType
设置为其他设备)。然而,该逻辑无法预测用户是否会在 run_fn
内部将张量移动到新设备(“新”指不属于 [当前设备 + Tensor 参数的设备] 集合的设备)。因此,如果您在 run_fn
内部将张量移动到新设备,则无法保证与非检查点传递相比的确定性输出。
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source]#
检查模型或模型的一部分。
激活检查点是一种以计算换取内存的技术。检查点区域中的前向计算会省略保存反向所需的张量,并在反向传播期间重新计算它们,而不是将反向所需的张量保留到反向传播中的梯度计算中使用。激活检查点可以应用于模型的任何部分。
目前有两种可用的检查点实现,由
use_reentrant
参数确定。建议您使用use_reentrant=False
。请参阅下面的说明以了解它们的差异。警告
如果在反向传播期间
function
的调用与前向传播不同,例如由于全局变量,则检查点版本可能不相等,这可能会导致错误或导致静默的不正确梯度。警告
应显式传递
use_reentrant
参数。在版本 2.4 中,如果未传递use_reentrant
,我们将引发异常。如果您使用的是use_reentrant=True
变体,请参阅下面的说明以了解重要的注意事项和潜在限制。注意
检查点的可重入变体(
use_reentrant=True
)和检查点的不可重入变体(use_reentrant=False
)在以下方面有所不同:不可重入检查点会在所有需要的中间激活重新计算完成后立即停止重新计算。此功能默认启用,但可以使用
set_checkpoint_early_stop()
禁用。可重入检查点始终在反向传播期间完整地重新计算function
。可重入变体在前向传播期间不记录自动微分图,因为它在前向传播中在
torch.no_grad()
下运行。不可重入版本确实记录自动微分图,允许在检查点区域内对图进行反向传播。可重入检查点仅支持不带 inputs 参数的
torch.autograd.backward()
API 进行反向传播,而不可重入版本支持所有进行反向传播的方式。对于可重入变体,至少一个输入和输出必须具有
requires_grad=True
。如果不满足此条件,则模型的检查点部分将不具有梯度。不可重入版本没有此要求。可重入版本不认为嵌套结构(例如,自定义对象、列表、字典等)中的张量参与自动微分,而不可重入版本则认为参与。
可重入检查点不支持计算图中包含分离张量的检查点区域,而不可重入版本支持。对于可重入变体,如果检查点段包含使用
detach()
或torch.no_grad()
分离的张量,则反向传播将引发错误。这是因为checkpoint
使所有输出需要梯度,这在张量在模型中定义为无梯度时会导致问题。为避免这种情况,请在checkpoint
函数之外分离张量。
- 参数
function – 描述模型或模型部分在前向传播中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户传递
(activation, hidden)
,function
应该正确地将第一个输入用作activation
,将第二个输入用作hidden
preserve_rng_state (bool, 可选) – 在每个检查点期间省略存储和恢复 RNG 状态。请注意,在 torch.compile 下,此标志不起作用,我们始终保留 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要可重入自动微分的激活检查点变体。此参数应显式传递。在版本 2.5 中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,checkpoint
将使用不需要可重入自动微分的实现。这允许checkpoint
支持附加功能,例如与torch.autograd.grad
按预期工作以及支持将关键字参数输入到检查点函数。context_fn (Callable, 可选) – 一个可调用对象,返回一个包含两个上下文管理器的元组。函数及其重新计算将分别在第一个和第二个上下文管理器下运行。此参数仅在
use_reentrant=False
时受支持。determinism_check (str, 可选) – 指定要执行的确定性检查的字符串。默认情况下,它设置为
"default"
,它比较重新计算的张量与保存的张量的形状、dtype 和设备。要关闭此检查,请指定"none"
。目前,这些是仅有的两个受支持的值。如果您希望看到更多确定性检查,请提交一个问题。此参数仅在use_reentrant=False
时受支持;如果use_reentrant=True
,则确定性检查始终禁用。debug (bool, 可选) – 如果为
True
,错误消息还将包括原始前向计算以及重新计算期间运行的操作员的跟踪。此参数仅在use_reentrant=False
时受支持。args – 包含
function
输入的元组
- 返回
在
*args
上运行function
的输出
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]#
对顺序模型进行检查点,以节省内存。
顺序模型按顺序(依次)执行模块/函数列表。因此,我们可以将这样的模型分成不同的段并对每个段进行检查点。除了最后一段之外,所有段都不会存储中间激活。每个检查点段的输入将被保存,以便在反向传播中重新运行该段。
警告
参数
use_reentrant
应该显式传递。在 2.4 版本中,如果未传递use_reentrant
,我们将抛出异常。如果您正在使用use_reentrant=True` 变体,请参阅 :func:`~torch.utils.checkpoint.checkpoint` 以了解此变体的重要注意事项和限制。建议您使用 ``use_reentrant=False
。- 参数
functions – 一个
torch.nn.Sequential
或要顺序运行的模块或函数列表(构成模型)。segments – 在模型中创建的块数
input –
functions
的输入张量preserve_rng_state (bool, optional) – 在每个检查点期间省略存储和恢复 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要可重入自动微分的激活检查点变体。此参数应显式传递。在版本 2.5 中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,checkpoint
将使用不需要可重入自动微分的实现。这允许checkpoint
支持附加功能,例如与torch.autograd.grad
按预期工作以及支持将关键字参数输入到检查点函数。
- 返回
在
*inputs
上顺序运行functions
的输出
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]#
上下文管理器,用于设置检查点在运行时是否应打印额外的调试信息。有关更多信息,请参阅
checkpoint()
的debug
标志。请注意,当设置时,此上下文管理器会覆盖传递给检查点的debug
值。要委托给本地设置,请将None
传递给此上下文。- 参数
enabled (bool) – 检查点是否应打印调试信息。默认值为“None”。
- class torch.utils.checkpoint.CheckpointPolicy(value)[source]#
用于在反向传播期间指定检查点策略的枚举。
支持以下策略
{MUST,PREFER}_SAVE
:操作的输出将在前向传播期间保存,并且在反向传播期间不会重新计算{MUST,PREFER}_RECOMPUTE
:操作的输出在前向传播期间不会保存,并且将在反向传播期间重新计算
使用
MUST_*
而不是PREFER_*
来表示该策略不应被其他子系统(如 torch.compile)覆盖。注意
一个总是返回
PREFER_RECOMPUTE
的策略函数等同于香草检查点。一个每个操作都返回
PREFER_SAVE
的策略函数不等同于不使用检查点。使用这样的策略会保存额外的张量,不限于梯度计算实际需要的张量。
- class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source]#
在选择性检查点期间传递给策略函数的上下文。
此类别用于在选择性检查点期间将相关元数据传递给策略函数。元数据包括当前策略函数的调用是否在重新计算期间。
示例
>>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> print(ctx.is_recompute) >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )
- torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source]#
在激活检查点期间避免重新计算某些操作的辅助函数。
与 torch.utils.checkpoint.checkpoint 一起使用,以控制在反向传播期间重新计算哪些操作。
- 参数
policy_fn_or_list (Callable or List) –
如果提供了策略函数,它应该接受一个
SelectiveCheckpointContext
、OpOverload
、操作的 args 和 kwargs,并返回一个CheckpointPolicy
枚举值,指示是否应该重新计算操作的执行。如果提供了操作列表,则等同于对指定操作返回 CheckpointPolicy.MUST_SAVE,对所有其他操作返回 CheckpointPolicy.PREFER_RECOMPUTE 的策略。
allow_cache_entry_mutation (bool, optional) – 默认情况下,如果选择性激活检查点缓存的任何张量被修改,则会引发错误,以确保正确性。如果设置为 True,则禁用此检查。
- 返回
一个包含两个上下文管理器的元组。
示例
>>> import functools >>> >>> x = torch.rand(10, 10, requires_grad=True) >>> y = torch.rand(10, 10, requires_grad=True) >>> >>> ops_to_save = [ >>> torch.ops.aten.mm.default, >>> ] >>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> if op in ops_to_save: >>> return CheckpointPolicy.MUST_SAVE >>> else: >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> # or equivalently >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> >>> def fn(x, y): >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )