结合使用 Distributed DataParallel 和 Distributed RPC Framework#
创建日期:2020 年 7 月 28 日 | 最后更新:2023 年 6 月 6 日 | 最后验证:未验证
作者:Pritam Damania 和 Yi Wang
注意
在 github 上查看和编辑本教程。
本教程使用一个简单的示例来演示如何结合使用 DistributedDataParallel (DDP) 和 Distributed RPC framework,以结合分布式数据并行和分布式模型并行来训练一个简单的模型。示例的源代码可以在 这里 找到。
之前的教程 分布式数据并行入门 和 分布式 RPC 框架入门 分别介绍了如何进行分布式数据并行和分布式模型并行训练。然而,存在几种您可能希望结合这两种技术的训练范式。例如:
如果我们有一个具有稀疏部分(大型嵌入表)和稠密部分(FC 层)的模型,我们可能希望将嵌入表放在参数服务器上,并使用 DistributedDataParallel 将 FC 层复制到多个训练器上。 Distributed RPC framework 可用于在参数服务器上执行嵌入查找。
启用 PipeDream 论文中描述的混合并行。我们可以使用 Distributed RPC framework 将模型阶段管道化到多个工作节点上,并在需要时使用 DistributedDataParallel 复制每个阶段。
在本教程中,我们将介绍上面提到的第一种情况。我们的设置中总共有 4 个工作节点,如下所示:
1 个 Master,负责在参数服务器上创建嵌入表 (nn.EmbeddingBag)。Master 还负责驱动两个训练器上的训练循环。
1 个 Parameter Server,它基本上将嵌入表保存在内存中,并响应来自 Master 和 Trainers 的 RPC。
2 个 Trainers,它们存储一个 FC 层 (nn.Linear),该层使用 DistributedDataParallel 在它们之间进行复制。Trainers 还负责执行前向传播、反向传播和优化器步骤。
整个训练过程执行如下:
Master 创建一个 RemoteModule,该模块在 Parameter Server 上保存嵌入表。
然后,Master 启动训练器上的训练循环,并将远程模块传递给训练器。
Trainers 创建一个名为
HybridModel的模型,该模型首先使用 Master 提供的远程模块执行嵌入查找,然后执行 DDP 包装的 FC 层。Trainer 执行模型的正向传播,并使用损失通过 Distributed Autograd 执行反向传播。
在反向传播过程中,首先计算 FC 层的梯度,并通过 DDP 中的 allreduce 同步到所有 Trainer。
接下来,Distributed Autograd 将梯度传播到参数服务器,并在那里更新嵌入表的梯度。
最后,使用 Distributed Optimizer 来更新所有参数。
注意
如果您正在结合使用 DDP 和 RPC,在反向传播时应始终使用 Distributed Autograd。
现在,让我们详细了解每个部分。首先,在进行任何训练之前,我们需要设置好所有工作节点。我们创建 4 个进程,其中 rank 0 和 1 是我们的 Trainers,rank 2 是 Master,rank 3 是 Parameter Server。
我们使用 TCP init_method 在所有 4 个工作节点上初始化 RPC 框架。RPC 初始化完成后,Master 使用 RemoteModule 在 Parameter Server 上创建一个包含 EmbeddingBag 层的远程模块。然后,Master 循环遍历每个 Trainer,并通过使用 rpc_async 调用每个 Trainer 上的 _run_trainer 来启动训练循环。最后,Master 在退出前等待所有训练完成。
Trainers 首先使用 init_process_group 为 DDP 初始化一个 ProcessGroup,world_size=2(表示两个 Trainer)。接下来,它们使用 TCP init_method 初始化 RPC 框架。请注意,RPC 初始化和 ProcessGroup 初始化使用的端口是不同的。这是为了避免两个框架初始化之间的端口冲突。初始化完成后,Trainers 只需等待来自 Master 的 _run_trainer RPC。
Parameter Server 只初始化 RPC 框架,并等待来自 Trainers 和 Master 的 RPC。
def run_worker(rank, world_size):
r"""
A wrapper function that initializes RPC, calls the function, and shuts down
RPC.
"""
# We need to use different port numbers in TCP init_method for init_rpc and
# init_process_group to avoid port conflicts.
rpc_backend_options = TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = "tcp://:29501"
# Rank 2 is master, 3 is ps and 0 and 1 are trainers.
if rank == 2:
rpc.init_rpc(
"master",
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
remote_emb_module = RemoteModule(
"ps",
torch.nn.EmbeddingBag,
args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
kwargs={"mode": "sum"},
)
# Run the training loop on trainers.
futs = []
for trainer_rank in [0, 1]:
trainer_name = "trainer{}".format(trainer_rank)
fut = rpc.rpc_async(
trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
)
futs.append(fut)
# Wait for all training to finish.
for fut in futs:
fut.wait()
elif rank <= 1:
# Initialize process group for Distributed DataParallel on trainers.
dist.init_process_group(
backend="gloo", rank=rank, world_size=2, init_method="tcp://:29500"
)
# Initialize RPC.
trainer_name = "trainer{}".format(rank)
rpc.init_rpc(
trainer_name,
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
# Trainer just waits for RPCs from master.
else:
rpc.init_rpc(
"ps",
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
# parameter server do nothing
pass
# block until all rpcs finish
rpc.shutdown()
if __name__ == "__main__":
# 2 trainers, 1 parameter server, 1 master.
world_size = 4
mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
在讨论 Trainer 的细节之前,让我们先介绍一下 Trainer 使用的 HybridModel。如下所述,HybridModel 使用一个在参数服务器上保存嵌入表的远程模块 (remote_emb_module) 和用于 DDP 的 device 进行初始化。模型的初始化将一个 nn.Linear 层包装在 DDP 中,以在所有 Trainer 之间复制和同步该层。
模型的正向传播方法非常直接。它使用 RemoteModule 的 forward 方法在参数服务器上执行嵌入查找,并将其输出传递给 FC 层。
class HybridModel(torch.nn.Module):
r"""
The model consists of a sparse part and a dense part.
1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
This remote model can get a Remote Reference to the embedding table on the parameter server.
"""
def __init__(self, remote_emb_module, device):
super(HybridModel, self).__init__()
self.remote_emb_module = remote_emb_module
self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
self.device = device
def forward(self, indices, offsets):
emb_lookup = self.remote_emb_module.forward(indices, offsets)
return self.fc(emb_lookup.cuda(self.device))
接下来,我们看看 Trainer 上的设置。Trainer 首先使用一个在参数服务器上保存嵌入表的远程模块及其自己的 rank 来创建上述 HybridModel。
现在,我们需要检索一个 RRefs 列表,其中包含我们希望使用 DistributedOptimizer 优化的所有参数。要从参数服务器检索嵌入表的参数,我们可以调用 RemoteModule 的 remote_parameters,它会遍历嵌入表的所有参数并返回一个 RRefs 列表。Trainer 通过 RPC 调用参数服务器上的此方法,以接收所需的参数的 RRefs 列表。由于 DistributedOptimizer 始终接受一个 RRefs 列表作为需要优化的参数,因此我们还需要为 FC 层的本地参数创建 RRefs。这是通过遍历 model.fc.parameters(),为每个参数创建一个 RRef,并将其追加到从 remote_parameters() 返回的列表中来完成的。请注意,我们不能使用 model.parameters(),因为这会递归调用 model.remote_emb_module.parameters(),而这不受 RemoteModule 支持。
最后,我们使用所有 RRefs 创建 DistributedOptimizer,并定义一个 CrossEntropyLoss 函数。
def _run_trainer(remote_emb_module, rank):
r"""
Each trainer runs a forward pass which involves an embedding lookup on the
parameter server and running nn.Linear locally. During the backward pass,
DDP is responsible for aggregating the gradients for the dense part
(nn.Linear) and distributed autograd ensures gradients updates are
propagated to the parameter server.
"""
# Setup the model.
model = HybridModel(remote_emb_module, rank)
# Retrieve all model parameters as rrefs for DistributedOptimizer.
# Retrieve parameters for embedding table.
model_parameter_rrefs = model.remote_emb_module.remote_parameters()
# model.fc.parameters() only includes local parameters.
# NOTE: Cannot call model.parameters() here,
# because this will call remote_emb_module.parameters(),
# which supports remote_parameters() but not parameters().
for param in model.fc.parameters():
model_parameter_rrefs.append(RRef(param))
# Setup distributed optimizer
opt = DistributedOptimizer(
optim.SGD,
model_parameter_rrefs,
lr=0.05,
)
criterion = torch.nn.CrossEntropyLoss()
现在我们可以介绍在每个 Trainer 上运行的主训练循环了。get_next_batch 只是一个用于生成随机输入和目标以供训练的辅助函数。我们对多个 epoch 运行训练循环,对于每个 batch:
为 Distributed Autograd 设置一个 Distributed Autograd Context。
运行模型的正向传播并检索其输出。
使用损失函数根据我们的输出和目标计算损失。
使用 Distributed Autograd 来执行分布式反向传播,以损失为依据。
最后,运行一个 Distributed Optimizer 步骤来优化所有参数。
def get_next_batch(rank):
for _ in range(10):
num_indices = random.randint(20, 50)
indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)
# Generate offsets.
offsets = []
start = 0
batch_size = 0
while start < num_indices:
offsets.append(start)
start += random.randint(1, 10)
batch_size += 1
offsets_tensor = torch.LongTensor(offsets)
target = torch.LongTensor(batch_size).random_(8).cuda(rank)
yield indices, offsets_tensor, target
# Train for 100 epochs
for epoch in range(100):
# create distributed autograd context
for indices, offsets, target in get_next_batch(rank):
with dist_autograd.context() as context_id:
output = model(indices, offsets)
loss = criterion(output, target)
# Run distributed backward pass
dist_autograd.backward(context_id, [loss])
# Tun distributed optimizer
opt.step(context_id)
# Not necessary to zero grads as each iteration creates a different
# distributed autograd context which hosts different grads
print("Training done for epoch {}".format(epoch))
整个示例的源代码可以在 这里 找到。