评价此页

分布式自动微分设计#

创建于:2019 年 11 月 12 日 | 最后更新于:2021 年 9 月 3 日

本文档将介绍分布式自动微分的详细设计,并深入探讨其内部机制。在继续阅读之前,请确保您已熟悉 自动微分机制分布式 RPC 框架

背景#

假设您有两个节点,并且一个非常简单的模型被划分到这两个节点上。这可以使用 torch.distributed.rpc 实现,如下所示:

import torch
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)

# Compute some loss.
loss = t5.sum()

分布式自动微分的主要动机是为了能够对这种分布式模型运行反向传播,使用我们已经计算出的 loss,并为所有需要梯度的张量记录适当的梯度。

前向传播过程中的自动微分记录#

PyTorch 在前向传播过程中构建自动微分图,该图用于执行反向传播。有关更多详细信息,请参阅 自动微分如何编码历史记录

对于分布式自动微分,我们需要在前向传播过程中跟踪所有 RPC 调用,以确保反向传播能够正确执行。为此,当我们执行 RPC 时,我们会将 sendrecv 函数附加到自动微分图中。

  • send 函数附加到 RPC 的源端,其输出边指向 RPC 输入张量的自动微分函数。在反向传播期间,此函数的输入是从目标节点通过相应的 recv 函数接收的。

  • recv 函数附加到 RPC 的目标端,其输入通过目标端使用输入张量执行的操作来检索。此函数的输出梯度在反向传播期间作为输出发送到源节点,传递给相应的 send 函数。

  • 每个 send-recv 对都被分配一个全局唯一的 autograd_message_id,以唯一标识该对。这在反向传播期间查找远程节点上的相应函数时非常有用。

  • 对于 RRef,每当我们调用 torch.distributed.rpc.RRef.to_here() 时,我们都会附加一个适当的 send-recv 对来处理涉及的张量。

例如,我们上面示例的自动微分图将如下所示(为简化起见,已排除 t5.sum()):

../_images/send_recv_functions.png

分布式自动微分上下文#

每个使用分布式自动微分的前向和反向传播都会被分配一个唯一的 torch.distributed.autograd.context,并且该上下文有一个全局唯一的 autograd_context_id。这个上下文会在需要时在每个节点上创建。

此上下文具有以下目的:

  1. 多个节点运行分布式反向传播可能会在同一个张量上累积梯度,因此在我们有机会运行优化器之前,张量的 .grad 字段将包含来自各种分布式反向传播的梯度。这类似于多次在本地调用 torch.autograd.backward()。为了提供一种区分每次反向传播的梯度的方法,梯度被累积在每个反向传播的 torch.distributed.autograd.context 中。

  2. 在前向传播期间,我们将每个自动微分过程的 sendrecv 函数存储在此上下文中。这确保我们持有自动微分图中相应节点的引用,以保持其活动。此外,这使得在反向传播期间更容易查找相应的 sendrecv 函数。

  3. 通常,我们还使用此上下文为每个分布式自动微分过程存储一些元数据。


从用户的角度来看,自动微分上下文的设置如下:

import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
  loss = model.forward()
  dist_autograd.backward(context_id, loss)

需要注意的是,您的模型的前向传播必须在分布式自动微分上下文管理器中调用,因为需要一个有效的上下文才能确保所有 sendrecv 函数被正确存储,以便在所有参与节点上运行反向传播。

分布式反向传播#

本节将概述在分布式反向传播过程中准确计算依赖关系的挑战,并描述几种(带有权衡的)执行分布式反向传播的算法。

计算依赖关系#

考虑在单台机器上运行的以下代码片段:

import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()

上述代码的自动微分图将如下所示:

../_images/local_dependencies.png

作为反向传播一部分的自动微分引擎执行的第一步是计算自动微分图中每个节点的依赖项数量。这有助于自动微分引擎知道何时一个图节点已准备好执行。对于 add(1)mul(0),方括号中的数字表示依赖项的数量。如您所见,这意味着在反向传播期间,add 节点需要 1 个输入,而 mul 节点不需要任何输入(换句话说,不需要执行)。局部自动微分引擎通过从根节点(在本例中为 d)遍历图来计算这些依赖项。

自动微分图中某些节点可能不会在反向传播中执行的事实给分布式自动微分带来了挑战。考虑这段使用 RPC 的代码:

import torch
import torch.distributed.rpc as rpc

a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)

d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()

上述代码的关联自动微分图将是:

../_images/distributed_dependencies.png

计算这个分布式自动微分图的依赖关系更具挑战性,需要一些开销(无论是计算还是网络通信)。

对于性能敏感的应用,我们可以避免大量开销,假设每个 sendrecv 函数在反向传播中是有效的(大多数应用不会执行不使用的 RPC)。这简化了分布式自动微分算法,并且效率更高,但代价是应用程序需要了解其局限性。这个算法称为 FAST 模式算法,下面将详细描述。

在一般情况下,并非每个 sendrecv 函数在反向传播中都必须是有效的。为了解决这个问题,我们提出了一个 SMART 模式算法,将在后续部分进行描述。请注意,目前仅实现了 FAST 模式算法。

FAST 模式算法#

此算法的关键假设是,在运行反向传播时,每个 send 函数都有一个依赖项为 1。换句话说,我们假设我们将从另一个节点通过 RPC 接收到梯度。

算法如下:

  1. 我们从拥有反向传播根节点的 worker 开始(所有根节点必须是本地的)。

  2. 查找当前 分布式自动微分上下文 的所有 send 函数。

  3. 从提供的根节点和我们检索到的所有 send 函数开始,在本地计算依赖关系。

  4. 计算完依赖关系后,使用提供的根节点启动本地自动微分引擎。

  5. 当自动微分引擎执行 recv 函数时,recv 函数通过 RPC 将输入梯度发送到相应的 worker。每个 recv 函数都知道目标 worker ID,因为它是在前向传播过程中记录的。recv 函数还会将 autograd_context_idautograd_message_id 发送到远程主机。

  6. 当在远程主机上收到此请求时,我们使用 autograd_context_idautograd_message_id 来查找相应的 send 函数。

  7. 如果这是 worker 第一次收到给定 autograd_context_id 的请求,它将按照上述第 1-3 点在本地计算依赖关系。

  8. 在 6. 中检索到的 send 函数随后将被加入到该 worker 的本地自动微分引擎的执行队列中。

  9. 最后,我们不是将梯度累积到张量的 .grad 字段中,而是为每个 分布式自动微分上下文 分别累积梯度。梯度存储在 Dict[Tensor, Tensor] 中,它本质上是一个从张量到其关联梯度的映射,并且可以使用 get_gradients() API 检索此映射。


作为一个例子,下面是包含分布式自动微分的完整代码:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:

# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
  t1 = torch.rand((3, 3), requires_grad=True)
  t2 = torch.rand((3, 3), requires_grad=True)

  # Perform some computation remotely.
  t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

  # Perform some computation locally based on remote result.
  t4 = torch.rand((3, 3), requires_grad=True)
  t5 = torch.mul(t3, t4)

  # Compute some loss.
  loss = t5.sum()

  # Run the backward pass.
  dist_autograd.backward(context_id, [loss])

  # Retrieve the gradients from the context.
  dist_autograd.get_gradients(context_id)

带有依赖关系的分布式自动微分图将如下所示(为简化起见,已排除 t5.sum()):

../_images/distributed_dependencies_computed.png

应用于上述示例的 FAST 模式算法 将如下所示:

  1. Worker 0 上,我们从根节点 losssend1 开始计算依赖关系。结果是 send1 被标记为具有 1 的依赖性,而 Worker 0 上的 mul 被标记为具有 1 的依赖性。

  2. 现在,我们在 Worker 0 上启动本地自动微分引擎。我们首先执行 mul 函数,将其输出累积在自动微分上下文中作为 t4 的梯度。然后,我们执行 recv2,它将梯度发送到 Worker 1

  3. 由于这是 Worker 1 第一次收到有关此反向传播的消息,因此它开始计算依赖关系,并相应地标记 send2addrecv1 的依赖关系。

  4. 接下来,我们将 send2 加入到 Worker 1 的本地自动微分引擎的队列中,该引擎随后执行 addrecv1

  5. 当执行 recv1 时,它将梯度发送到 Worker 0

  6. 由于 Worker 0 已经为这个反向传播计算了依赖关系,它只是在本地加入并执行 send1

  7. 最后,t1t2t4 的梯度将累积在 分布式自动微分上下文 中。

SMART 模式算法#

该算法的完整细节仍在开发中,但有关其大致思想,您可以参考 RFC 中的 **分布式自动微分算法 SMART 模式** 部分。

分布式优化器#

DistributedOptimizer 的工作原理如下:

  1. 接受一个要优化的远程参数列表(RRef)。这些也可以是包装在本地 RRef 中的本地参数。

  2. 接受一个 Optimizer 类作为本地优化器,在所有不同的 RRef 所有者上运行。

  3. 分布式优化器在每个 worker 节点上创建一个本地 Optimizer 实例,并持有指向它们的 RRef

  4. 当调用 torch.distributed.optim.DistributedOptimizer.step() 时,分布式优化器使用 RPC 来远程执行所有适当的远程 worker 上的本地优化器。必须将分布式自动微分 context_id 作为输入提供给 torch.distributed.optim.DistributedOptimizer.step()。这被本地优化器用来应用存储在相应上下文中的梯度。

  5. 如果多个并发的分布式优化器正在更新 worker 上的同一组参数,这些更新将通过锁进行序列化。

简单的端到端示例#

将所有内容整合在一起,下面是一个使用分布式自动微分和分布式优化器的简单端到端示例。如果将代码放在名为“dist_autograd_simple.py”的文件中,可以使用命令 MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py 来运行。

import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(rank, dst_rank, world_size):
    name = "worker{}".format(rank)
    dst_name = "worker{}".format(dst_rank)

    # Initialize RPC.
    rpc.init_rpc(
        name=name,
        rank=rank,
        world_size=world_size
    )

    # Use a distributed autograd context.
    with dist_autograd.context() as context_id:
        # Forward pass (create references on remote nodes).
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass (run distributed autograd).
        dist_autograd.backward(context_id, [loss.sum()])

        # Build DistributedOptimizer.
        dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
        )

        # Run the distributed optimizer step.
        dist_optim.step(context_id)

def run_process(rank, world_size):
    dst_rank = (rank + 1) % world_size
    _run_process(rank, dst_rank, world_size)
    rpc.shutdown()

if __name__ == '__main__':
  # Run world_size workers
  world_size = 2
  mp.spawn(run_process, args=(world_size,), nprocs=world_size)