管理器#
此模块实现了管理完整容错训练循环的管理器。
管理器负责管理整个训练循环,与 Lighthouse 服务器通信以确定法定人数,在恢复时重新配置 ProcessGroups 和恢复检查点状态。
这使用了包装类来包装标准的 PyTorch Optimizer 和 Module 类,以提供容错能力。这些包装器的目的是在用户模型代码和训练循环中进行最小的更改以增加容错能力。
这旨在与标准的 PyTorch DistributedDataParallel 模块和 Hybrid FSDP 一起使用。
- class torchft.manager.Manager(pg: ProcessGroup, load_state_dict: Optional[Callable[[T], None]], state_dict: Optional[Callable[[], T]], min_replica_size: int, use_async_quorum: bool = True, timeout: timedelta = datetime.timedelta(seconds=60), quorum_timeout: timedelta = datetime.timedelta(seconds=60), connect_timeout: timedelta = datetime.timedelta(seconds=60), rank: Optional[int] = None, world_size: Optional[int] = None, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, store_addr: Optional[str] = None, store_port: Optional[int] = None, lighthouse_addr: Optional[str] = None, replica_id: Optional[str] = None, port: Optional[int] = None, hostname: str = 'pkrvmjbmru5nbw0', heartbeat_interval: timedelta = datetime.timedelta(microseconds=100000), checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None, init_sync: bool = True, max_retries: Optional[int] = None, quorum_retries: int = 0)[source]#
基类:
object
管理器管理完整的容错训练循环。
这要求由 store_addr 和 store_port 或 MASTER_ADDR 和 MASTER_PORT 环境变量指定的 TCPStore 在创建此管理器之前启动。如果使用较新版本的 torchelastic,则情况已是如此。否则,应在创建此管理器之前通过 torch.distributed.init_process_group 启动它。
注意:在保存周期性检查点时,必须同时保存和恢复管理器的 state_dict,以避免同步问题。
- allreduce(tensor: Tensor, should_quantize: bool = False) Work [source]#
对张量进行容错 allreduce 并返回一个 Future,当张量就绪时将完成。
这将自动将张量按 1 / world_size 缩放。
如果在 allreduce 过程中发生错误
Future 将在没有错误的情况下完成,而是异步跟踪。
第一次错误后,所有后续调用都将成为 noops 并立即返回。
张量在被使用之前必须归零,因为它可能会损坏。
- 参数
tensor – 要 allreduce 的张量
should_quantize – 在通信之前是否应量化张量
- 返回
一个 Future,当张量 allreduce 完成时将得到完成。
- batches_committed() int [source]#
获取所有步和副本提交的总批次数。参与 2 步的 5 个副本是 10 个批次,但根据批次大小,可能多于 10 个示例。
此数字在 .step() 上递增
- 返回
已提交的总批次数
- errored() Optional[ExceptionWithTraceback] [source]#
获取是否发生错误。
- 返回
错误,如果未发生错误则为 None。
- load_state_dict(state_dict: Dict[str, int]) None [source]#
从先前的检查点加载 state_dict。
这将恢复步数和内部元数据。
- 参数
state_dict – 要加载的 state_dict
- num_participants() int [source]#
获取当前法定人数中的参与者数量。
这是参与当前步骤的副本数量。
这将阻塞异步法定人数,如果它尚未就绪。
- 返回
当前法定人数中的参与者数量
- participating_rank() Optional[int] [source]#
获取当前法定人数的副本组秩。在副本组内的所有秩上,这将是相同的。
如果此副本组未参与当前法定人数,则此项为 None。
这将阻塞异步法定人数,如果它尚未就绪。
- 返回
当前法定人数的秩
- register_state_dict_fn(key: str, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]) None [source]#
- report_error(e: Exception) None [source]#
向管理器报告错误。
这将导致管理器跳过当前步骤,并在下一步进行重新配置。
当发生导致梯度损坏需要丢弃的错误时,应调用此函数。
- set_state_dict_fns(load_state_dict: Callable[[T], None], state_dict: Callable[[], T]) None [source]#
- should_commit(timeout: Optional[timedelta] = None) bool [source]#
注意
我们建议使用
torchft.optim.OptimizerWrapper
而不是直接调用此函数。必须在反向传播完成但优化器步进之前调用。
仅当此函数返回 True 时才应步进优化器。
必须在副本组内的所有工作进程上调用此函数。它使用集体通信来确保副本内的所有工作进程返回相同的值。如果在任何工作进程上发生错误,所有工作进程都将返回 False。不同的副本组可能返回不同的值。
每步最多应调用一次此函数。
如果设置了 max_retries 并且 should_commit 连续失败达到该次数,则此方法将引发 RuntimeError 以防止无限失败循环。
- 返回
如果应步进优化器,则返回 True,否则返回 False
- 引发
RuntimeError – 如果 should_commit 连续失败 max_retries 次且 max_retries 已设置
- start_quorum(allow_heal: bool = True, shrink_only: bool = False, timeout: Optional[timedelta] = None) None [source]#
注意
我们建议使用
torchft.optim.OptimizerWrapper
而不是直接调用此函数。计算新的法定人数(可能异步地)并使管理器准备好进行新步骤。
最佳实践是在每个步骤的前向传播之前调用此函数,因为计算法定人数可能需要一些时间。
- 参数
allow_heal – (实验性)是否允许在步骤开始时进行修复。如果设置了 allow_heal,管理器将在返回之前尝试同步修复,或在任何网络调用之前异步修复。所有副本都必须传递相同的值给 allow_heal。
timeout – 法定人数就绪的超时时间,如果为 None,则使用管理器超时,恢复操作将使用管理器超时。