评价此页

使用 Join 上下文管理器进行不均匀输入分布式训练#

创建于:2021 年 8 月 4 日 | 最后更新:2025 年 9 月 3 日 | 最后验证:2024 年 11 月 5 日

作者Andrew Gu

注意

editgithub 上查看和编辑本教程。

注意

Join 作为一项原型功能在 PyTorch 1.10 中引入。此 API 可能会发生更改。

在本教程中,您将看到

  • Join 上下文管理器的概述。

  • 一个如何将上下文管理器与 DistributedDataParallel 结合使用的示例。

  • 一个如何将上下文管理器与 DistributedDataParallelZeroRedundancyOptimizer 结合使用的示例。

  • 一个传递关键字参数给上下文管理器的示例。

  • 深入了解 Join 上下文管理器的工作原理。

  • 一个展示如何使玩具类与上下文管理器兼容的示例。

要求#

什么是 Join#

使用分布式数据并行入门 - 基本用例 中,您了解了使用 DistributedDataParallel 进行数据并行训练的通用框架。这会在每次反向传播时隐式调度 all-reduces 操作以同步不同 rank 之间的梯度。此类 集体通信 需要进程组中的所有 rank 参与,因此如果某个 rank 的输入较少,其他 rank 将会挂起或报错(取决于后端)。更广泛地说,这个问题会存在于任何执行每迭代同步集体通信的类中。

Join 是一个上下文管理器,用于包装每个 rank 的训练循环,以支持不均匀输入进行训练。该上下文管理器允许尽早耗尽输入的 rank(即,尽早加入)来“隐藏”尚未加入的 rank 所执行的集体通信。通信的隐藏方式由 hooks 指定。

JoinDistributedDataParallel 结合使用#

PyTorch 的 DistributedDataParallel 可与 Join 上下文管理器开箱即用。以下是一个示例用法:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将产生以下输出(来自 rank 0 和 rank 1 的 print() 语句的顺序可能任意):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

注意

在此通用 Join 上下文管理器引入之前,DistributedDataParallel 提供了自己的 join() 上下文管理器。在上面的示例中,使用 with Join([model]): 等同于使用 with model.join():。现有 DistributedDataParallel.join() 的一个限制是它不允许多个参与类,例如 DistributedDataParallelZeroRedundancyOptimizer 同时使用。

JoinDistributedDataParallelZeroRedundancyOptimizer 结合使用#

Join 上下文管理器不仅可以与单个类一起工作,还可以与多个类一起工作。PyTorch 的 ZeroRedundancyOptimizer 也与上下文管理器兼容,因此在此,我们考察如何修改之前的示例以同时使用 DistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将产生与之前相同的输出。值得注意的更改是额外将 ZeroRedundancyOptimizer 实例传递给了 Join()

传递关键字参数#

类可以在运行时提供关键字参数来修改其在上下文管理器中的行为。例如,DistributedDataParallel 提供了一个 divide_by_initial_world_size 参数,该参数决定梯度是除以初始 world size 还是除以有效 world size(即非加入 rank 的数量)。此类关键字参数可以直接传递给上下文管理器。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

传递给上下文管理器的关键字参数会在所有参与类之间共享。这不应成为限制,因为我们不期望出现多个 Joinable 需要相同参数的不同设置的情况。尽管如此,这一点还是值得注意的。

Join 是如何工作的?#

现在我们已经看到了一些使用 Join 上下文管理器的初步示例,让我们深入了解它是如何工作的。这将帮助您更深入地了解它提供的全部功能,并为您准备好使自己的自定义类兼容。在这里,我们将介绍 Join 类以及支持类 JoinableJoinHook

Joinable#

首先,与 Join 上下文管理器兼容的类必须继承自抽象基类 Joinable。特别是,一个 Joinable 必须实现:

  • join_hook(self, **kwargs) -> JoinHook

这会返回 JoinableJoinHook 实例,该实例决定已加入的进程应如何“隐藏” Joinable 在每次迭代中执行的集体通信。

  • join_device(self) -> torch.device

这会返回一个设备,供 Join 上下文管理器用于执行集体通信,例如 torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

这会返回一个进程组,供 Join 上下文管理器用于执行集体通信。

特别是,join_devicejoin_process_group 是必需的属性,以确保上下文管理器能够调度已加入和未加入进程之间的集体通信。一种用法是使用 all-reduce 在每次迭代中计算未加入进程的数量。另一种用法是实现 throw_on_early_termination=True 所需的机制,我们将在下文进一步解释。

DistributedDataParallelZeroRedundancyOptimizer 已经继承了 Joinable 并实现了上述方法,这就是为什么我们可以在之前的示例中直接使用它们。

Joinable 类应确保调用 Joinable 构造函数,因为它会初始化一个 JoinConfig 实例,该实例由上下文管理器内部使用以确保正确性。这将作为字段 _join_config 保存在每个 Joinable 中。

JoinHook#

接下来,我们分解 JoinHook 类。一个 JoinHook 为上下文管理器提供了两个入口点:

  • main_hook(self) -> None

此 hook 在存在尚未加入的 rank 时,被每个已加入的 rank 反复调用。它旨在“隐藏” Joinable 在每个训练迭代中执行的集体通信(例如,一次前向传播、反向传播和优化器步进)。

  • post_hook(self, is_last_joiner: bool) -> None

所有 rank 加入后,此 hook 会被调用一次。它会接收一个额外的布尔参数 is_last_joiner,该参数指示该 rank 是否是最后加入的 rank 之一。该参数可能对同步有用。

为了给出这些 hook 可能样子的具体示例,提供的 ZeroRedundancyOptimizer main hook 正常执行一次优化器步进,因为已加入的 rank 仍然负责更新和同步其参数分片;而提供的 DistributedDataParallel post-hook 则将最终更新的模型从最后一个加入的 rank 之一广播出去,以确保它在所有 rank 中都相同。

Join#

最后,让我们看看这些如何融入 Join 类本身。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我们在前面的示例中所见,构造函数接受参与训练循环的 Joinable 对象列表。这些应该是每个迭代中执行集体通信的类。

enable 是一个布尔值,如果知道不会出现不均匀输入,则可以将其设置为 False。在这种情况下,上下文管理器将变得空泛,类似于 contextlib.nullcontext()。这也会禁用参与的 Joinable 中的 join 相关计算。

throw_on_early_termination 是一个布尔值,如果设置为 True,则会在检测到不均匀输入时让每个 rank 引发一个异常。这对于不符合上下文管理器要求的情况很有用,最常见的情况是当 DistributedDataParallel 与具有 SyncBatchNorm 层的模型一起使用时,存在来自不同类的集体通信可能任意交织。在这种情况下,应将此参数设置为 True,以便应用程序逻辑可以捕获异常并确定如何继续。

  • 核心逻辑发生在 __exit__() 方法中,该方法在存在未加入的 rank 时循环,调用每个 Joinable 的 main hook,然后一旦所有 rank 都加入,就会调用它们的 post hook。main hook 和 post hook 都按照传递 Joinable 的顺序进行迭代。

  • 上下文管理器要求未加入的进程发送心跳信号。因此,每个 Joinable 类在其每次迭代的集体通信之前都应调用 Join.notify_join_context()。上下文管理器将确保只有传递的第一个 Joinable 实际发送心跳。

警告

如上所述关于 throw_on_early_terminationJoin 上下文管理器与某些类的组合不兼容。JoinableJoinHook 必须是可序列化的,因为每个 hook 在继续执行下一个之前都会完全执行。换句话说,两个 hook 不能重叠。此外,目前 main hook 和 post hook 都以相同的确定性顺序进行迭代。如果这似乎是一个主要限制,我们可能会修改 API 以允许可自定义的排序。

使玩具类与 Join 兼容#

由于上一节介绍了一些概念,让我们通过一个玩具示例在实践中看看它们。在这里,我们将实现一个类,该类计算在自己的 rank 加入之前,所有 rank 中看到的输入数量。这应该提供一个基本的想法,说明如何使您自己的类与 Join 上下文管理器兼容。

具体来说,以下代码让每个 rank 打印出 (1) 在其加入之前所有 rank 中看到的输入数量,以及 (2) 所有 rank 中的总输入数量。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于 rank 0 看到 5 个输入,rank 1 看到 6 个输入,这将产生输出:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

一些需要强调的关键点

  • 一个 Counter 实例每迭代执行一次 all-reduce,因此 main hook 也执行一次 all-reduce 来“隐藏”它。

  • Counter 类在其 __call__() 方法的开头调用 Join.notify_join_context(),因为这是在其每次迭代的集体通信(即其 all-reduce)之前的位置。

  • is_last_joiner 参数用于确定 post-hooks 中的广播源。

  • 我们将 sync_max_count 关键字参数传递给上下文管理器,然后该参数会被转发到 Counter 的 join hook。