• 文档 >
  • 对 Torch Distributed 的支持
快捷方式

Torch Distributed 支持

在 2.5 版本之前,PyTorch/XLA 仅通过自定义 API 调用 torch_xla.core.xla_model.* 支持集体操作。在 2.5 版本中,我们在 PyTorch/XLA 中为 Dynamo 和非 Dynamo 情况都采用了 torch.distributed.*

集体操作的降低

集体操作的降低堆栈

在引入了可跟踪的集体通信 API后,Dynamo 可以通过在 PyTorch/XLA 中重新实现降低来支持集体操作。集体操作只能通过 torch.ops._c10d_functional 调用进行跟踪。下图展示了在 torch 和 torch_xla 之间如何降低集体操作(此处以 all_reduce 为例)

Alt Text

图 1. 集体操作的降低堆栈

非 Dynamo 情况

集体操作通过注册 ProcessGroupXla 进行降低,该类派生自 ProcessGroup

# torch_xla/distributed/xla_backend.py
def _create_xla_process_group(prefix_store, rank, size, timeout):
  assert not xr.is_spmd(
  ), "XLA backend is not supported with SPMD. Please use a CPU process group instead."
  return ProcessGroupXla(prefix_store, rank, size, timeout)


def _register_xla_backend():
  dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla')


class ProcessGroupXla(ProcessGroup):
  ...
  def allreduce(self, tensors, all_reduce_options):
    ...
  def allgather(self, output_tensors_list, input_tensors, opts=None):
    ...

当我们调用时,会初始化相应的 xla dist 后端。

def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')

In this way, collective ops will be called based on the progress group instance:

  # E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
  @_exception_logger
  def all_gather(tensor_list, tensor, group=None, async_op=False):
    ...
    group = group or _get_default_group()
    work = group.allgather([tensor_list], [tensor]) # uses ProcessGroupXla.allgather instead

Dynamo 情况

对于 Dynamo 情况,某些集体操作被重新映射到pytorch/torch/distributed/_functional_collectives.py中的新函数。例如,all_reduce() 将被映射到 all_reduce_inplace(),最终调用 torch.ops._c10d_functional.all_reduce()。一旦我们到达 _c10d_functional,我们就可以通过 PyTorch/Xla 降低来重写操作。

at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
                      std::string /*group_name*/)  {...}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  m.impl("all_reduce", all_reduce);
}

API 描述

对于 2.5 版本,我们现在为 Dynamo 和非 Dynamo 情况都支持四种集体操作。我们的目标是将分布式操作 (dist op) API 与 PyTorch 的上游实现保持一致。虽然函数签名保持一致,但某些输入限制仍然适用。例如,尚未支持为分布式集体操作指定多个组。有关用法示例,请参阅test_collective_ops_tpu.py,它演示了在 Dynamo 和非 Dynamo 场景中使用 dist op。以下是每个操作的详细信息:

dist.all_reduce(input: torch.Tensor, op: dist.ReduceOp = ReduceOp.SUM)

all_reduce 通过聚合来自所有节点的数据,对 input 张量执行就地归约。

dist.all_gather_into_tensor(output, input)

all_gather_into_tensor 从所有节点收集输入张量,并就地更新 output 张量。它还返回输出的别名。

dist.reduce_scatter_tensor(output, input, op: dist.ReduceOp = ReduceOp.SUM)

reduce_scatter_tensor 跨所有节点归约输入张量,并将结果就地分发到 output 张量。它返回输出的别名。

dist.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None)

all_to_all_single 函数执行 all-to-all 通信,就地更新输出张量并返回其别名。

注意:虽然接受 output_split_sizesinput_split_sizes 作为参数,但它们必须为 None 或设置为全部为 1。此限制反映了在保持 PyTorch 的 API 签名和 XLA AllToAll 操作的约束之间的折衷。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源