Torch Distributed 支持¶
在 2.5 版本之前,PyTorch/XLA 仅通过自定义 API 调用 torch_xla.core.xla_model.*
支持集体操作。在 2.5 版本中,我们在 PyTorch/XLA 中为 Dynamo 和非 Dynamo 情况都采用了 torch.distributed.*
。
集体操作的降低¶
集体操作的降低堆栈¶
在引入了可跟踪的集体通信 API后,Dynamo 可以通过在 PyTorch/XLA 中重新实现降低来支持集体操作。集体操作只能通过 torch.ops._c10d_functional
调用进行跟踪。下图展示了在 torch 和 torch_xla 之间如何降低集体操作(此处以 all_reduce
为例)

图 1. 集体操作的降低堆栈
非 Dynamo 情况¶
集体操作通过注册 ProcessGroupXla
进行降低,该类派生自 ProcessGroup
。
# torch_xla/distributed/xla_backend.py
def _create_xla_process_group(prefix_store, rank, size, timeout):
assert not xr.is_spmd(
), "XLA backend is not supported with SPMD. Please use a CPU process group instead."
return ProcessGroupXla(prefix_store, rank, size, timeout)
def _register_xla_backend():
dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla')
class ProcessGroupXla(ProcessGroup):
...
def allreduce(self, tensors, all_reduce_options):
...
def allgather(self, output_tensors_list, input_tensors, opts=None):
...
当我们调用时,会初始化相应的 xla dist 后端。
def _mp_fn(rank):
dist.init_process_group("xla", init_method='xla://')
In this way, collective ops will be called based on the progress group instance:
# E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
@_exception_logger
def all_gather(tensor_list, tensor, group=None, async_op=False):
...
group = group or _get_default_group()
work = group.allgather([tensor_list], [tensor]) # uses ProcessGroupXla.allgather instead
Dynamo 情况¶
对于 Dynamo 情况,某些集体操作被重新映射到pytorch/torch/distributed/_functional_collectives.py中的新函数。例如,all_reduce()
将被映射到 all_reduce_inplace()
,最终调用 torch.ops._c10d_functional.all_reduce()
。一旦我们到达 _c10d_functional,我们就可以通过 PyTorch/Xla 降低来重写操作。
at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
std::string /*group_name*/) {...}
TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
m.impl("all_reduce", all_reduce);
}
API 描述¶
对于 2.5 版本,我们现在为 Dynamo 和非 Dynamo 情况都支持四种集体操作。我们的目标是将分布式操作 (dist op) API 与 PyTorch 的上游实现保持一致。虽然函数签名保持一致,但某些输入限制仍然适用。例如,尚未支持为分布式集体操作指定多个组。有关用法示例,请参阅test_collective_ops_tpu.py,它演示了在 Dynamo 和非 Dynamo 场景中使用 dist op。以下是每个操作的详细信息:
dist.all_reduce(input: torch.Tensor, op: dist.ReduceOp = ReduceOp.SUM)
all_reduce
通过聚合来自所有节点的数据,对 input
张量执行就地归约。
dist.all_gather_into_tensor(output, input)
all_gather_into_tensor
从所有节点收集输入张量,并就地更新 output
张量。它还返回输出的别名。
dist.reduce_scatter_tensor(output, input, op: dist.ReduceOp = ReduceOp.SUM)
reduce_scatter_tensor
跨所有节点归约输入张量,并将结果就地分发到 output
张量。它返回输出的别名。
dist.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None)
all_to_all_single
函数执行 all-to-all 通信,就地更新输出张量并返回其别名。
注意:虽然接受 output_split_sizes
和 input_split_sizes
作为参数,但它们必须为 None 或设置为全部为 1。此限制反映了在保持 PyTorch 的 API 签名和 XLA AllToAll 操作的约束之间的折衷。