分布式自动求导设计#
创建于:2019年11月12日 | 最后更新于:2021年9月3日
本笔记将介绍分布式自动求导的详细设计并深入探讨其内部机制。在继续之前,请确保您熟悉自动求导机制和分布式RPC框架。
背景#
假设您有两个节点,并且有一个非常简单的模型在这两个节点上进行了分区。这可以使用torch.distributed.rpc实现,如下所示:
import torch
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
分布式自动求导的主要目的是,对于此类分布式模型,能够使用我们计算出的loss运行反向传播,并为所有需要梯度的张量记录适当的梯度。
前向传播期间的自动求导记录#
PyTorch 在前向传播期间构建自动求导图,该图用于执行反向传播。更多详情请参阅自动求导如何编码历史记录。
对于分布式自动求导,我们需要在前向传播期间跟踪所有RPC,以确保反向传播被正确执行。为此,当执行RPC时,我们将send和recv函数附加到自动求导图上。
send函数附加到RPC的源端,其输出边指向RPC输入张量的自动求导函数。在反向传播期间,此函数的输入是从目标端作为相应recv函数的输出接收的。recv函数附加到RPC的目标端,其输入通过在目标端使用输入张量执行的操作符获取。在反向传播期间,此函数的输出梯度被发送到源节点,交给相应的send函数。每个
send-recv对被分配一个全局唯一的autograd_message_id,以唯一标识该对。这在反向传播期间用于在远程节点上查找对应的函数。对于RRef,每当我们调用
torch.distributed.rpc.RRef.to_here()时,我们都会为所涉及的张量附加一个适当的send-recv对。
例如,上面示例的自动求导图如下所示(为简化起见,t5.sum()已排除):
分布式自动求导上下文#
每次使用分布式自动求导的前向传播和反向传播都会被分配一个唯一的torch.distributed.autograd.context,并且此上下文具有全局唯一的autograd_context_id。此上下文在每个节点上按需创建。
此上下文用于以下目的:
多个运行分布式反向传播的节点可能会在同一个张量上累积梯度,导致在有机会运行优化器之前,张量的
.grad字段中包含来自各种分布式反向传播的梯度。这类似于在本地多次调用torch.autograd.backward()。为了提供一种方法来分离每次反向传播的梯度,梯度会针对每次反向传播累积在torch.distributed.autograd.context中。在前向传播期间,我们将每个自动求导过程的
send和recv函数存储在此上下文中。这确保了我们持有对自动求导图中相应节点的引用,以使其保持活跃。此外,在反向传播期间,查找适当的send和recv函数也很容易。通常,我们还使用此上下文为每个分布式自动求导过程存储一些元数据。
从用户的角度来看,自动求导上下文设置如下:
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
loss = model.forward()
dist_autograd.backward(context_id, loss)
重要的是要注意,您模型的前向传播必须在分布式自动求导上下文管理器中调用,因为需要一个有效的上下文来确保所有send和recv函数都被正确存储,以便在所有参与节点上运行反向传播。
分布式反向传播#
在本节中,我们将概述在分布式反向传播期间准确计算依赖关系的挑战,并描述几种(带有权衡)算法,说明如何执行分布式反向传播。
计算依赖关系#
考虑以下在单机上运行的代码片段:
import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()
上述代码的自动求导图如下所示:
自动求导引擎在反向传播中执行的第一步是计算自动求导图中每个节点的依赖关系数量。这有助于自动求导引擎知道图中何时某个节点已准备好执行。add(1)和mul(0)括号中的数字表示依赖关系的数量。如您所见,这意味着在反向传播期间,add节点需要1个输入,而mul节点不需要任何输入(换句话说,不需要执行)。本地自动求导引擎通过从根节点(在此例中为d)遍历图来计算这些依赖关系。
自动求导图中某些节点可能不会在反向传播中执行这一事实给分布式自动求导带来了挑战。考虑以下使用RPC的代码片段:
import torch
import torch.distributed.rpc as rpc
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()
上述代码的关联自动求导图将是:
计算此分布式自动求导图的依赖关系更具挑战性,并且需要一定的开销(无论是计算还是网络通信)。
对于性能敏感的应用程序,我们可以通过假设每个send和recv函数都是反向传播的有效部分(大多数应用程序不会执行未使用的RPC)来避免大量开销。这简化了分布式自动求导算法并提高了效率,但代价是应用程序需要了解其局限性。此算法称为FAST模式算法,并将在下面详细描述。
在一般情况下,并非每个send和recv函数都必须是反向传播的有效部分。为了解决这个问题,我们提出了SMART模式算法,将在后面的章节中描述。请注意,目前仅实现了FAST模式算法。
FAST模式算法#
此算法的关键假设是,当运行反向传播时,每个send函数都具有1个依赖关系。换句话说,我们假设将通过RPC从另一个节点接收梯度。
该算法如下:
我们从拥有反向传播根节点的工作节点开始(所有根节点必须是本地的)。
查找当前分布式自动求导上下文的所有
send函数。从提供的根节点和我们检索到的所有
send函数开始,在本地计算依赖关系。计算完依赖关系后,使用提供的根节点启动本地自动求导引擎。
当自动求导引擎执行
recv函数时,recv函数通过RPC将输入梯度发送到相应的工作节点。每个recv函数都知道目标工作节点ID,因为它作为前向传播的一部分被记录。recv函数还会将autograd_context_id和autograd_message_id发送到远程主机。当远程主机收到此请求时,我们使用
autograd_context_id和autograd_message_id查找相应的send函数。如果这是工作节点首次收到给定
autograd_context_id的请求,它将按照上面第1-3点所述在本地计算依赖关系。在第6点中检索到的
send函数随后被排队,以便在该工作节点的本地自动求导引擎上执行。最后,我们不将梯度累积在张量的
.grad字段上,而是针对每个分布式自动求导上下文单独累积梯度。梯度存储在一个Dict[Tensor, Tensor]中,这基本上是将张量映射到其相关梯度的映射,并且可以使用get_gradients()API检索此映射。
举例来说,包含分布式自动求导的完整代码如下:
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
带有依赖关系的分布式自动求导图如下所示(为简化起见,t5.sum()已排除):
将FAST模式算法应用于上述示例,如下所示:
在
Worker 0上,我们从根节点loss和send1开始计算依赖关系。结果是send1被标记为依赖关系1,Worker 0上的mul被标记为依赖关系1。现在,我们启动
Worker 0上的本地自动求导引擎。我们首先执行mul函数,将其输出作为t4的梯度累积到自动求导上下文中。然后,我们执行recv2,它将梯度发送到Worker 1。由于这是
Worker 1首次收到有关此反向传播的信息,它会开始计算依赖关系,并适当地标记send2、add和recv1的依赖关系。接下来,我们将
send2加入Worker 1的本地自动求导引擎的队列中,该引擎随后会执行add和recv1。当
recv1执行时,它将梯度发送到Worker 0。由于
Worker 0已经计算了此反向传播的依赖关系,它只是在本地将send1入队并执行。最后,
t1、t2和t4的梯度累积在分布式自动求导上下文中。
SMART模式算法#
此算法的完整细节仍在开发中,但大致思路可以参考RFC中的“Distributed Autograd Algorithm Smart mode”部分。
分布式优化器#
DistributedOptimizer的运作方式如下:
接受一个远程参数列表(
RRef)进行优化。这些参数也可以是包装在本地RRef中的本地参数。接受一个
Optimizer类作为本地优化器,用于在所有不同的RRef所有者上运行。分布式优化器在每个工作节点上创建本地
Optimizer的一个实例,并持有对它们的RRef。当调用
torch.distributed.optim.DistributedOptimizer.step()时,分布式优化器使用RPC在相应的远程工作节点上远程执行所有本地优化器。必须向torch.distributed.optim.DistributedOptimizer.step()提供一个分布式自动求导context_id作为输入。本地优化器使用此ID来应用存储在相应上下文中的梯度。如果多个并发的分布式优化器正在更新同一个工作节点上的相同参数,这些更新会通过锁进行序列化。
简单端到端示例#
综合来看,以下是一个使用分布式自动求导和分布式优化器的简单端到端示例。如果代码放在名为“dist_autograd_simple.py”的文件中,可以使用命令MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py运行。
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
if __name__ == '__main__':
# Run world_size workers
world_size = 2
mp.spawn(run_process, args=(world_size,), nprocs=world_size)