训练脚本#
创建于:2021年5月04日 | 最后更新于:2023年2月09日
如果你的训练脚本使用 torch.distributed.launch 工作,它将继续与 torchrun 一起工作,但有以下不同:
无需手动传递
RANK、WORLD_SIZE、MASTER_ADDR和MASTER_PORT。可以提供
rdzv_backend和rdzv_endpoint。对于大多数用户,这将设置为c10d(参见 rendezvous)。默认的rdzv_backend创建一个非弹性的 rendezvous,其中rdzv_endpoint包含主地址。请确保你的脚本中包含
load_checkpoint(path)和save_checkpoint(path)的逻辑。当任意数量的工作进程失败时,我们将使用相同的程序参数重新启动所有工作进程,因此你将丢失直到最近一次检查点之间的所有进度(参见 elastic launch)。use_env标志已被移除。如果你曾通过解析--local-rank选项来解析本地 rank,则需要从环境变量LOCAL_RANK获取本地 rank(例如,int(os.environ["LOCAL_RANK"]))。
下面是一个训练脚本的说明性示例,该脚本在每个 epoch 进行检查点保存,因此在失败时丢失的最大进度相当于一个完整的 epoch 训练。
def main():
args = parse_args(sys.argv[1:])
state = load_checkpoint(args.checkpoint_path)
initialize(state)
# torch.distributed.run ensures that this will work
# by exporting all the env vars needed to initialize the process group
torch.distributed.init_process_group(backend=args.backend)
for i in range(state.epoch, state.total_num_epochs)
for batch in iter(state.dataset)
train(batch, state.model)
state.epoch += 1
save_checkpoint(state)
有关符合 torchelastic 的训练脚本的具体示例,请访问我们的 示例 页面。