简介 || 什么是 DDP || 单节点多 GPU 训练 || 容错 || 多节点训练 || minGPT 训练
使用 torchrun 进行容错分布式训练#
创建于:2022 年 9 月 27 日 | 最后更新:2024 年 11 月 12 日 | 最后验证:2024 年 11 月 5 日
请跟随下面的视频或在 youtube 上观看。
在分布式训练中,单个进程的失败可能会中断整个训练任务。由于这里的故障敏感性可能更高,因此使您的训练脚本具有鲁棒性尤为重要。您可能还希望您的训练任务是弹性的,例如,计算资源可以在任务过程中动态地加入和离开。
PyTorch 提供了一个名为 torchrun 的实用程序,它提供了容错和弹性训练功能。当发生故障时,torchrun 会记录错误并尝试从训练任务的最后一个保存的“快照”自动重启所有进程。
快照保存的内容不仅仅是模型状态;它还可以包含有关已运行的 epoch 数量、优化器状态或训练任务连续性所需的任何其他有状态属性的详细信息。
为什么使用 torchrun#
torchrun 处理分布式训练的细节,因此您无需这样做。例如,
您无需设置环境变量或显式传递
rank和world_size;torchrun会分配这些以及其他几个 环境变量。无需在脚本中调用
mp.spawn;您只需要一个通用的main()入口点,然后使用torchrun启动脚本。这样,同一个脚本就可以在非分布式、单节点和多节点设置中运行。从最后一个保存的训练快照平稳地重启训练。
平稳重启#
为了实现平稳重启,您应该像这样构建您的训练脚本
def main():
load_snapshot(snapshot_path)
initialize()
train()
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_snapshot(snapshot_path)
如果发生故障,torchrun 将终止所有进程并重新启动它们。每个进程入口点首先加载并初始化最后一个保存的快照,然后从那里继续训练。因此,在任何故障发生时,您只会丢失自最后一个保存的快照以来的训练进度。
在弹性训练中,每当发生任何成员资格更改(添加或删除节点)时,torchrun 将终止并在可用设备上生成进程。拥有此结构可确保您的训练任务能够继续进行,而无需手动干预。
进程组初始化#
torchrun会自动分配RANK和WORLD_SIZE,以及 其他环境变量
- def ddp_setup(rank, world_size):
+ def ddp_setup():
- """
- Args:
- rank: Unique identifier of each process
- world_size: Total number of processes
- """
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = "12355"
- init_process_group(backend="nccl", rank=rank, world_size=world_size)
+ init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
使用 torchrun 提供的环境变量#
- self.gpu_id = gpu_id
+ self.gpu_id = int(os.environ["LOCAL_RANK"])
保存和加载快照#
定期将所有相关信息存储在快照中,使我们的训练任务能够在中断后无缝恢复。
+ def _save_snapshot(self, epoch):
+ snapshot = {}
+ snapshot["MODEL_STATE"] = self.model.module.state_dict()
+ snapshot["EPOCHS_RUN"] = epoch
+ torch.save(snapshot, "snapshot.pt")
+ print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")
+ def _load_snapshot(self, snapshot_path):
+ snapshot = torch.load(snapshot_path)
+ self.model.load_state_dict(snapshot["MODEL_STATE"])
+ self.epochs_run = snapshot["EPOCHS_RUN"]
+ print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
在 Trainer 构造函数中加载快照#
当恢复中断的训练任务时,您的脚本将首先尝试加载快照以继续训练。
class Trainer:
def __init__(self, snapshot_path, ...):
...
+ if os.path.exists(snapshot_path):
+ self._load_snapshot(snapshot_path)
...
恢复训练#
训练可以从最后一个运行的 epoch 继续,而无需从头开始。
def train(self, max_epochs: int):
- for epoch in range(max_epochs):
+ for epoch in range(self.epochs_run, max_epochs):
self._run_epoch(epoch)
运行脚本#
就像运行非多进程脚本一样调用您的入口点函数;torchrun 会自动生成进程。
if __name__ == "__main__":
import sys
total_epochs = int(sys.argv[1])
save_every = int(sys.argv[2])
- world_size = torch.cuda.device_count()
- mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
+ main(save_every, total_epochs)
- python multigpu.py 50 10
+ torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10
进一步阅读#
使用 DDP 进行多节点训练(本系列的下一篇教程)
使用 DDP 进行多 GPU 训练(本系列上一篇教程)