评价此页

LocalSGD#

此模块实现了 LocalSGD 的容错版本及相关方法。

class torchft.local_sgd.DiLoCo(manager: Manager, model_fragments: List[Module], inner_optimizer: Optimizer, outer_optimizer: torch.optim.optimizer.Optimizer | list[torch.optim.optimizer.Optimizer], sync_every: int, backup_device: Optional[device] = None, pin_memory: bool = True, use_bucketization: bool = False, bucket_cap_mb: Optional[int] = None, should_quantize: bool = False, fragment_sync_delay: int = 0, fragment_update_alpha: float = 0.0)[来源]#

基类: object

DiLoCo 是 LocalSGD 的一个子类,它重写了同步机制,用于对伪梯度(先前全局权重与当前局部权重之差)进行平均和同步。

该类实现了一个更通用的 DiLoco 版本,即流式 DiLoCo,它在不同步骤同步伪梯度的片段。

此算法需要权重的备份副本。默认情况下,这些备份副本存储在 CPU 内存中。如果在 DiLoCo 步骤中发生任何错误,该步骤将被丢弃,模型参数将恢复到 DiLoCo 上一次同步时的状态。

DiLoCo 论文: https://arxiv.org/pdf/2311.08105 流式 DiLoCo 论文: https://arxiv.org/pdf/2501.18512

class torchft.local_sgd.LocalSGD(manager: Manager, model: Module, optimizer: Optimizer, sync_every: int)[来源]#

基类: object

LocalSGD 是一个上下文管理器,实现了 https://arxiv.org/pdf/1805.09767 中描述的算法。

它将使用 torchft Manager 以容错方式定期同步模型参数。allreduce 参数将在 optimizer.step 调用后的每 sync_every 步发生。

torchft 仲裁将在 sync_every 步开始时计算。如果发生任何错误,或在同步之间有 worker 失败,sync_every 步将被丢弃,并在下一步计算新的仲裁。

如果在异步模式下运行,对于加入的 worker,前 sync_every 步将被丢弃,因为模型在此期间将进行恢复。在使用同步模式时,将在第一步之前恢复检查点。

sync() None[来源]#

同步并平均管理器上的模型权重。

torchft.local_sgd.extract_local_tensor(t: Tensor) Tensor[来源]#

返回输入张量的一个克隆版本。如果输入张量是 DTensor,则提取并克隆其本地表示。