评价此页

使用 Join 上下文管理器进行不均匀输入分布式训练#

创建日期:2021 年 8 月 4 日 | 最后更新:2025 年 9 月 3 日 | 最后验证:2024 年 11 月 5 日

作者Andrew Gu

注意

editgithub 上查看和编辑本教程。

注意

Join 在 PyTorch 1.10 中作为原型特性引入。此 API 可能会发生变化。

在本教程中,您将看到

  • Join 上下文管理器的概述。

  • 如何将上下文管理器与 DistributedDataParallel 结合使用的示例。

  • 如何将上下文管理器与 DistributedDataParallelZeroRedundancyOptimizer 结合使用的示例。

  • 将关键字参数传递到上下文管理器的示例。

  • 深入了解 Join 上下文管理器的工作原理。

  • 一个示例,展示如何使一个玩具类与上下文管理器兼容。

要求#

Join 是什么?#

使用分布式数据并行入门 - 基本用例 中,您看到了使用 DistributedDataParallel 执行数据并行训练的一般框架。这会在每次反向传播中隐式地安排 all-reduce 操作,以在各个 rank 之间同步梯度。此类 集体通信 需要过程组中所有 rank 的参与,因此如果某个 rank 的输入较少,则其他 rank 将会挂起或出错(具体取决于后端)。更一般地说,这个问题对于执行每迭代同步集体通信的任何类都存在。

Join 是一个上下文管理器,用于围绕您的每个 rank 的训练循环使用,以促进使用不均匀输入的训练。该上下文管理器允许尽早耗尽其输入的 rank(即,join 尽早)来屏蔽那些尚未 join 的 rank 执行的集体通信。通信的屏蔽方式由钩子指定。

JoinDistributedDataParallel 结合使用#

PyTorch 的 DistributedDataParallel 可以开箱即用地与 Join 上下文管理器一起使用。这是一个示例用法

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将产生以下输出(其中 rank 0 和 rank 1 的 print() 语句的顺序可能任意)

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

注意

DistributedDataParallel 在引入此通用 Join 上下文管理器之前,提供了自己的 join() 上下文管理器。在上面的示例中,使用 with Join([model]): 等效于使用 with model.join():。现有 DistributedDataParallel.join() 的一个限制是它不允许使用多个参与类,例如 DistributedDataParallelZeroRedundancyOptimizer 一起使用。

JoinDistributedDataParallelZeroRedundancyOptimizer 结合使用#

Join 上下文管理器不仅可以与单个类一起使用,还可以与多个类一起使用。PyTorch 的 ZeroRedundancyOptimizer 也与上下文管理器兼容,因此,我们在此检查如何修改前面的示例以同时使用 DistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将产生与之前相同的输出。值得注意的变化是额外将 ZeroRedundancyOptimizer 实例传递到 Join() 中。

传递关键字参数#

类可以提供关键字参数,这些参数可以在运行时修改其在上下文管理器中的行为。例如,DistributedDataParallel 提供了一个参数 divide_by_initial_world_size,它确定梯度是除以初始世界大小还是有效世界大小(即,未 join 的 rank 的数量)。此类关键字参数可以直接传递到上下文管理器中。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

传递到上下文管理器的关键字参数在所有参与类之间共享。这不应该是一个限制,因为我们不期望多个 Joinable 需要相同参数的不同设置。尽管如此,这仍然需要牢记。

Join 的工作原理?#

现在我们已经看到了如何使用 Join 上下文管理器的几个初步示例,让我们深入了解它的工作原理。这将提供对其提供的全部功能的更深入了解,并为您准备自定义类与上下文管理器兼容。

Joinable#

首先,与 Join 上下文管理器兼容的类必须继承自抽象基类 Joinable。特别是,一个 Joinable 必须实现

  • join_hook(self, **kwargs) -> JoinHook

这将返回 JoinableJoinHook 实例,确定已 join 的进程应如何屏蔽 Joinable 在每个迭代中执行的每迭代集体通信(例如,在一个前向传递、反向传递和优化器步骤中)。

  • join_device(self) -> torch.device

这将返回 Join 上下文管理器用于执行集体通信的设备,例如 torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

这将返回 Join 上下文管理器用于执行集体通信的过程组。

特别是,join_devicejoin_process_group 是必需的属性,以确保上下文管理器可以安排已 join 和未 join 的进程之间的集体通信。一种用法是在每个迭代中使用 all-reduce 来计算未 join 的进程数量。另一种用法是实现 throw_on_early_termination=True 所需的机制,我们稍后会解释。

DistributedDataParallelZeroRedundancyOptimizer 已经继承自 Joinable 并实现了上述方法,这就是为什么我们可以在前面的示例中直接使用它们的原因。

Joinable 类应确保调用 Joinable 构造函数,因为它会初始化一个 JoinConfig 实例,该实例由上下文管理器内部用于确保正确性。这将保存在每个 Joinable 中作为一个字段 _join_config

JoinHook#

接下来,让我们分解 JoinHook 类。一个 JoinHook 提供两个进入上下文管理器的入口点

  • main_hook(self) -> None

此钩子由每个已 join 的 rank 反复调用,只要存在尚未 join 的 rank。它的目的是屏蔽 Joinable 在每个训练迭代中执行的集体通信(例如,在一个前向传递、反向传递和优化器步骤中)。

  • post_hook(self, is_last_joiner: bool) -> None

此钩子在所有 rank 都 join 后调用一次。它传递了一个额外的 bool 参数 is_last_joiner,它指示该 rank 是否是最后 join 的 rank 之一。该参数可能对同步有用。

为了给出这些钩子可能是什么样子的具体例子,提供的 ZeroRedundancyOptimizer 主钩子像往常一样执行一个优化器步骤,因为联合后的 rank 仍然负责更新和同步其参数的分片,而提供的 DistributedDataParallel 后钩子将最终更新的模型从最后一个加入的 rank 中的一个广播出去,以确保在所有 rank 上都是相同的。

Join#

最后,让我们检查这些如何适应 Join 类本身。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

如前所述的例子中所示,构造函数接收一个参与训练循环的 Joinable 列表。这些应该是每个迭代执行集体通信的类。

enable 是一个 bool,如果知道输入不会不均匀,则可以设置为 False,在这种情况下,上下文管理器将变得空虚,类似于 contextlib.nullcontext()。 这也可能会禁用参与的 Joinable 中的与 join 相关的计算。

throw_on_early_termination 是一个 bool,可以设置为 True,以便在检测到不均匀的输入时,每个 rank 都会引发异常。 这对于不符合上下文管理器要求的场景很有用,最常见的情况是来自不同类的集体通信被任意交错,例如在使用具有 SyncBatchNorm 层的 DistributedDataParallel 时。 在这种情况下,应将此参数设置为 True,以便应用程序逻辑可以捕获异常并确定如何继续。

  • 核心逻辑发生在 __exit__() 方法中,该方法循环遍历所有未加入的 rank,调用每个 Joinable 的主钩子,然后一旦所有 rank 都已加入,调用它们的后钩子。 主钩子和后钩子都以传入的 Joinable 的顺序进行迭代。

  • 上下文管理器需要来自未加入进程的心跳。 因此,每个 Joinable 类都应在其每次迭代的集体通信之前调用 Join.notify_join_context()。 上下文管理器将确保只有传入的第一个 Joinable 实际发送心跳。

警告

如前所述关于 throw_on_early_terminationJoin 上下文管理器与某些类的组合不兼容。 JoinableJoinHook 必须是可序列化的,因为每个钩子在继续到下一个钩子之前都会被完全执行。 换句话说,两个钩子不能重叠。 此外,目前,主钩子和后钩子都以相同的确定性顺序进行迭代。 如果这似乎是一个主要限制,我们可以修改 API 以允许可定制的排序。

让一个玩具类与 Join 协同工作#

由于上一节介绍了几个概念,让我们在实践中看看它们,使用一个玩具示例。 在这里,我们将实现一个类,该类在 rank 加入之前计算所有 rank 看到的输入数量。 这应该提供一个基本的想法,关于如何使您自己的类与 Join 上下文管理器兼容。

具体来说,以下代码让每个 rank 打印出 (1) 在其 rank 加入之前所有 rank 看到的输入数量,以及 (2) 所有 rank 的总输入数量。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于 rank 0 看到 5 个输入,而 rank 1 看到 6 个输入,因此输出结果为

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

一些关键点值得强调

  • 一个 Counter 实例每迭代一次执行一个 all-reduce,因此主钩子也执行一个 all-reduce 以对其进行阴影。

  • Counter 类在其 __call__() 方法的开头调用 Join.notify_join_context(),因为那是其每次迭代的集体通信(即,其 all-reduce)之前的一个位置。

  • is_last_joiner 参数用于确定后钩子中的广播源。

  • 我们将 sync_max_count 关键字参数传递给上下文管理器,然后将其转发给 Counter 的 join 钩子。