评价此页

使用 Join 上下文管理器进行不均匀输入分布式训练#

创建时间:2021 年 8 月 4 日 | 最后更新:2023 年 1 月 9 日 | 最后验证:2024 年 11 月 5 日

作者Andrew Gu

注意

editGitHub 上查看和编辑此教程。

注意

Join 作为原型功能引入 PyTorch 1.10。此 API 可能会发生更改。

在本教程中,您将看到

  • Join 上下文管理器的概述。

  • 一个如何将上下文管理器与 DistributedDataParallel 一起使用的示例。

  • 一个如何将上下文管理器与 DistributedDataParallelZeroRedundancyOptimizer 一起使用的示例。

  • 一个传递关键字参数给上下文管理器的示例。

  • 深入了解 Join 上下文管理器的工作原理。

  • 一个演示如何使玩具类与上下文管理器兼容的示例。

要求#

什么是 Join#

分布式数据并行入门 - 基本用例 中,您了解了使用 DistributedDataParallel 进行数据并行训练的通用框架。这会在每次反向传播时隐式调度 all-reduces 以在各个 rank 之间同步梯度。这类 集合通信 需要进程组中的所有 rank 参与,因此如果一个 rank 的输入较少,其他 rank 将会挂起或报错(取决于后端)。更广泛地说,这个问题对于任何执行按迭代同步集合通信的类都存在。

Join 是一个上下文管理器,用于包裹您的按迭代训练循环,以方便处理不均匀输入的情况。该上下文管理器允许输入提前耗尽的 rank(即提前加入)来模拟尚未加入的 rank 所执行的集合通信。通信被模拟的方式由钩子指定。

JoinDistributedDataParallel 一起使用#

PyTorch 的 DistributedDataParallel 可以直接与 Join 上下文管理器配合使用。以下是一个示例用法

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将产生以下输出(其中 rank 0 和 rank 1 的 print() 可能顺序任意):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

注意

在引入此通用 Join 上下文管理器之前,DistributedDataParallel 提供了自己的 join() 上下文管理器。在上面的示例中,使用 with Join([model]): 等同于使用 with model.join():。现有的 DistributedDataParallel.join() 的一个限制是它不允许多个参与类,例如同时使用 DistributedDataParallelZeroRedundancyOptimizer

JoinDistributedDataParallelZeroRedundancyOptimizer 一起使用#

Join 上下文管理器不仅可以与单个类一起使用,还可以与多个类一起使用。PyTorch 的 ZeroRedundancyOptimizer 也兼容该上下文管理器,因此在这里,我们研究如何修改之前的示例以同时使用 DistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将产生与之前相同的输出。值得注意的变化是将 ZeroRedundancyOptimizer 实例另外传递给了 Join()

传递关键字参数#

类可以提供关键字参数,在运行时修改它们在上下文管理器中的行为。例如,DistributedDataParallel 提供了一个参数 divide_by_initial_world_size,该参数决定梯度是除以初始 world size 还是除以有效 world size(即非加入 rank 的数量)。这类关键字参数可以直接传递到上下文管理器。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

传递给上下文管理器的关键字参数在所有参与类之间共享。这不应该成为限制,因为我们不期望出现多个 Joinable 需要相同参数的不同设置的情况。尽管如此,这仍然是需要注意的一点。

Join 是如何工作的?#

现在我们已经看到了一些使用 Join 上下文管理器的初步示例,让我们深入了解它是如何工作的。这将提供对其完整功能的更深入的了解,并为您准备好使自己的自定义类兼容。在这里,我们将介绍 Join 类以及支持类 JoinableJoinHook

Joinable#

首先,与 Join 上下文管理器兼容的类必须继承自抽象基类 Joinable。特别是,Joinable 必须实现

  • join_hook(self, **kwargs) -> JoinHook

这会返回 JoinableJoinHook 实例,该实例确定加入的进程如何模拟 Joinable 在每次迭代中执行的集合通信。

  • join_device(self) -> torch.device

这会返回一个由 Join 上下文管理器用于执行集合通信的设备,例如 torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

这会返回由 Join 上下文管理器用于执行集合通信的进程组。

特别是,join_devicejoin_process_group 是必需的属性,以确保上下文管理器可以调度加入和未加入进程之间的集合通信。一种用法是使用 all-reduce 来计算每次迭代中未加入进程的数量。另一种用法是用于实现 throw_on_early_termination=True 所需的机制,我们将在下面进一步解释。

DistributedDataParallelZeroRedundancyOptimizer 已经继承自 Joinable 并实现了上述方法,这就是为什么我们可以在之前的示例中直接使用它们。

Joinable 类应确保调用 Joinable 构造函数,因为它会初始化一个 JoinConfig 实例,该实例由上下文管理器在内部使用以确保正确性。这将作为字段 _join_config 保存在每个 Joinable 中。

JoinHook#

接下来,让我们分解 JoinHook 类。 JoinHook 为上下文管理器提供了两个入口点。

  • main_hook(self) -> None

在存在尚未加入的 rank 时,每个加入的 rank 会反复调用此钩子。它旨在模拟 Joinable 在每个训练迭代中执行的集合通信(例如,在一个前向传播、反向传播和优化器步骤中)。

  • post_hook(self, is_last_joiner: bool) -> None

在所有 rank 加入后,将调用此钩子。它会传递一个额外的布尔参数 is_last_joiner,该参数指示该 rank 是否是最后加入的 rank 之一。该参数可能有助于同步。

为了给出这些钩子可能是什么样子的具体示例,提供的 ZeroRedundancyOptimizer main hook 正常执行优化器步骤,因为加入的 rank 仍然负责更新和同步其参数分片,而提供的 DistributedDataParallel post-hook 从最后加入的 rank 之一广播最终更新的模型,以确保所有 rank 都相同。

Join#

最后,让我们看看这些如何融入 Join 类本身。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我们在前面的示例中看到的,构造函数接受一个参与训练循环的 Joinable 列表。这些应该是每个迭代中执行集合通信的类。

enable 是一个布尔值,如果已知不会出现不均匀输入,则可以将其设置为 False,在这种情况下,上下文管理器将变得空洞,类似于 contextlib.nullcontext()。这也可能会禁用参与的 Joinable 中的 join 相关计算。

throw_on_early_termination 是一个布尔值,可以设置为 True,以便在检测到不均匀输入时,每个 rank 都会引发一个异常。这对于不符合上下文管理器要求的情况很有用,最常见的是当 DistributedDataParallel 与具有 SyncBatchNorm 层的模型一起使用时。在这种情况下,应将此参数设置为 True,以便应用程序逻辑可以捕获异常并确定如何继续。

  • 核心逻辑发生在 __exit__() 方法中,该方法在存在未加入 rank 时循环,调用每个 Joinable 的 main hook,然后在所有 rank 加入后,调用它们的 post hook。main hooks 和 post-hooks 都是按照 Joinable 传递的顺序进行迭代的。

  • 上下文管理器需要来自未加入进程的心跳。因此,每个 Joinable 类都应该在其每次迭代的集合通信(即其 all-reduce)之前调用 Join.notify_join_context()。上下文管理器将确保只有传入的第一个 Joinable 实际发送心跳。

警告

如上关于 throw_on_early_termination 所述,Join 上下文管理器与某些类组合不兼容。 JoinableJoinHook s 必须是可序列化的,因为每个钩子在继续下一个之前都会被完全执行。换句话说,两个钩子不能重叠。此外,目前 main hooks 和 post-hooks 都是按照相同的确定性顺序进行迭代的。如果这似乎是一个主要限制,我们可能会修改 API 以允许自定义顺序。

使玩具类与 Join 一起工作#

由于上一节介绍了一些概念,让我们通过一个玩具示例来实践它们。在这里,我们将实现一个类,该类计算其 rank 加入之前所有 rank 中看到的输入数量。这应该能为您提供如何使自己的类与 Join 上下文管理器兼容的基本思路。

具体来说,以下代码让每个 rank 打印出(1)它加入之前所有 rank 中看到的输入数量,以及(2)所有 rank 中的总输入数量。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于 rank 0 看到 5 个输入,rank 1 看到 6 个输入,这会产生输出

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

一些需要强调的关键点

  • Counter 实例每个迭代执行一次 all-reduce,因此 main hook 也执行一次 all-reduce 来模拟它。

  • Counter 类在其 __call__() 方法的开头调用 Join.notify_join_context(),因为那是其按迭代集合通信(即其 all-reduce)之前的位置。

  • is_last_joiner 参数用于确定 post-hook 中的广播源。

  • 我们将 sync_max_count 关键字参数传递给上下文管理器,然后将其转发给 Counter 的 join hook。