评价此页

使用 ZeroRedundancyOptimizer 对优化器状态进行分片#

创建日期:2021年2月26日 | 最后更新:2021年10月20日 | 最后验证:未验证

在本实践中,您将学习

要求#

什么是 ZeroRedundancyOptimizer#

ZeroRedundancyOptimizer 的理念源于 DeepSpeed/ZeRO 项目Marian,旨在将优化器状态在分布式数据并行进程间进行分片,从而降低每个进程的内存占用。在《分布式数据并行入门》教程中,我们展示了如何使用 DistributedDataParallel (DDP) 来训练模型。在该教程中,每个进程都保存一份完整的优化器副本。由于 DDP 已经在反向传播中同步了梯度,因此所有优化器副本在每次迭代时都会对相同的参数和梯度值进行操作,这也是 DDP 保持模型副本状态一致的方式。通常,优化器还会维护一些局部状态。例如,Adam 优化器会使用每个参数对应的 exp_avgexp_avg_sq 状态。因此,Adam 优化器的内存消耗至少是模型大小的两倍。基于这一观察,我们可以通过在 DDP 进程间对优化器状态进行分片来减少优化器的内存占用。更具体地说,每个 DDP 进程中的优化器实例不再为所有参数创建状态,而是仅为模型参数的一个分片保留优化器状态。优化器的 step() 函数仅更新其分片内的参数,然后将其更新后的参数广播给所有其他对等 DDP 进程,从而确保所有模型副本仍然处于相同的状态。

如何使用 ZeroRedundancyOptimizer#

下面的代码演示了如何使用 ZeroRedundancyOptimizer。大部分代码与《分布式数据并行说明》中介绍的简单 DDP 示例类似。主要区别在于 example 函数中的 if-else 子句,它封装了优化器的构建过程,在 ZeroRedundancyOptimizerAdam 优化器之间进行切换。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

输出如下所示。当在 Adam 中启用 ZeroRedundancyOptimizer 时,优化器 step() 的峰值内存消耗是原生 Adam 消耗的一半。这符合我们的预期,因为我们将 Adam 优化器状态在两个进程间进行了分片。输出还显示,使用 ZeroRedundancyOptimizer 后,模型参数在经过一次迭代后仍然得到相同的值(无论是否使用 ZeroRedundancyOptimizer,参数总和都是相同的)。

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875