评价此页

LocalTensor 教程:单进程 SPMD 调试#

创建日期:2026 年 1 月 7 日 | 最后更新日期:2026 年 1 月 7 日

本教程介绍了 LocalTensor,这是一种强大的调试工具,用于开发和测试分布式张量操作,而无需多个进程或 GPU。

什么是 LocalTensor?#

LocalTensor 是一个 torch.Tensor 子类,它模拟了单进程上的分布式 SPMD(单程序多数据)计算。它在内部维护了一个从 rank ID 到其对应本地张量分片的映射,允许您在没有基础设施开销的情况下调试和测试分布式代码。

主要优势#

  1. 无需多进程设置:在单个 CPU/GPU 上测试分布式算法

  2. 更快的调试周期:无需启动多个进程即可快速迭代

  3. 完全可见性:直接检查每个 rank 的张量状态

  4. CI 友好:在单进程 CI 管道中运行分布式测试

  5. DTensor 集成:无缝本地测试 DTensor 代码

注意

LocalTensor 仅用于 调试和测试,不用于生产环境。在本地模拟多个 rank 的开销很大。

安装和设置#

LocalTensor 是 PyTorch 分布式包的一部分。除了 PyTorch 本身之外,不需要额外的安装。

使用示例#

以下示例演示了使用 LocalTensor 的核心模式。每个示例的代码都直接来自源代码文件,这些文件也经过测试以确保正确性。测试直接调用这些相同的函数。

示例 1:基本 LocalTensor 创建和操作#

从每个 rank 的张量创建 LocalTensor

def create_local_tensor():
    """Create a LocalTensor from per-rank tensors.

    Returns: (local_tensor, (expected_shape, expected_ranks, expected_rank_0, expected_rank_1))
    """
    rank_0_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    rank_1_tensor = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

    local_tensor = LocalTensor({0: rank_0_tensor, 1: rank_1_tensor})

    expected = (torch.Size([2, 2]), frozenset({0, 1}), rank_0_tensor, rank_1_tensor)
    return local_tensor, expected


算术运算(应用于每个 rank)

def arithmetic_operations():
    """Demonstrate arithmetic on LocalTensor.

    Returns: ((doubled, added), (expected_doubled_0, expected_doubled_1, expected_added_0))
    """
    input_0 = torch.tensor([1.0, 2.0, 3.0])
    input_1 = torch.tensor([4.0, 5.0, 6.0])

    lt = LocalTensor({0: input_0, 1: input_1})

    doubled = lt * 2
    added = lt + 10

    expected = (input_0 * 2, input_1 * 2, input_0 + 10)
    return (doubled, added), expected


当所有分片相同时提取张量

def reconcile_identical_shards():
    """Extract a single tensor when all shards are identical.

    Returns: (result, expected)
    """
    value = torch.tensor([1.0, 2.0, 3.0])
    lt = LocalTensor({0: value.clone(), 1: value.clone(), 2: value.clone()})

    result = lt.reconcile()
    return result, value


使用 LocalTensorMode 进行自动 LocalTensor 创建

def use_local_tensor_mode(world_size: int = 4):
    """Use LocalTensorMode to auto-create LocalTensors.

    Returns: ((is_local, num_ranks), (expected_is_local, expected_num_ranks))
    """
    with LocalTensorMode(world_size):
        x = torch.ones(2, 3)
        is_local = isinstance(x, LocalTensor)
        num_ranks = len(x._ranks)

    return (is_local, num_ranks), (True, world_size)


完整源代码: example_01_basic_operations.py

示例 2:模拟集体操作#

测试集体操作,如 all_reducebroadcastall_gather,而无需多个进程。

使用 SUM 进行 All-reduce

def all_reduce_sum(process_group):
    """Simulate all_reduce with SUM across ranks.

    Returns: (result, expected)
    """
    tensors = {
        0: torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
        1: torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
        2: torch.tensor([[9.0, 10.0], [11.0, 12.0]]),
    }

    expected = sum(tensors.values())

    with LocalTensorMode(frozenset(tensors.keys())):
        lt = LocalTensor({k: v.clone() for k, v in tensors.items()})
        dist.all_reduce(lt, op=dist.ReduceOp.SUM, group=process_group)
        result = lt.reconcile()

    return result, expected


从源 rank 广播

def broadcast_from_rank(process_group, src_rank: int = 0):
    """Simulate broadcast from a source rank.

    Returns: (result, expected)
    """
    tensors = {
        0: torch.tensor([10.0, 20.0, 30.0]),
        1: torch.tensor([40.0, 50.0, 60.0]),
        2: torch.tensor([70.0, 80.0, 90.0]),
    }

    expected = tensors[src_rank].clone()

    with LocalTensorMode(frozenset(tensors.keys())):
        lt = LocalTensor({k: v.clone() for k, v in tensors.items()})
        dist.broadcast(lt, src=src_rank, group=process_group)
        result = lt.reconcile()

    return result, expected


All-gather 以收集来自所有 rank 的张量

def all_gather_tensors(process_group):
    """Simulate all_gather to collect tensors from all ranks.

    Returns: (results_list, expected_list)
    """
    tensors = {
        0: torch.tensor([[1.0, 2.0]]),
        1: torch.tensor([[3.0, 4.0]]),
        2: torch.tensor([[5.0, 6.0]]),
    }
    num_ranks = len(tensors)

    expected = [tensors[i].clone() for i in range(num_ranks)]

    with LocalTensorMode(frozenset(tensors.keys())):
        lt = LocalTensor(tensors)
        output_list = [torch.zeros_like(lt) for _ in range(num_ranks)]
        dist.all_gather(output_list, lt, group=process_group)
        results = [out.reconcile() for out in output_list]

    return results, expected


完整源代码: example_02_collective_operations.py

示例 3:与 DTensor 配合使用#

LocalTensor 与 DTensor 集成,用于测试分布式张量并行性。

分发张量并验证重构

def distribute_and_verify(world_size: int = 4):
    """Distribute a tensor and verify reconstruction.

    Returns: ((sharded_actual, replicated_actual), (sharded_expected, replicated_expected))
    """
    with LocalTensorMode(world_size):
        mesh = init_device_mesh("cpu", (world_size,))
        tensor = torch.arange(16).reshape(4, 4).float()

        dt_sharded = distribute_tensor(tensor, mesh, [Shard(0)])
        dt_replicated = distribute_tensor(tensor, mesh, [Replicate()])

        sharded_actual = dt_sharded.full_tensor().reconcile()
        replicated_actual = dt_replicated.to_local().reconcile()

    return (sharded_actual, replicated_actual), (tensor, tensor)


分布式矩阵乘法

def dtensor_matmul(world_size: int = 4):
    """Perform matrix multiplication with DTensors.

    Returns: (actual, expected)
    """
    with LocalTensorMode(world_size):
        mesh = init_device_mesh("cpu", (world_size,))

        a = torch.randn(8, 4)
        b = torch.randn(4, 6)

        da = distribute_tensor(a, mesh, [Shard(0)])
        db = distribute_tensor(b, mesh, [Replicate()])

        dc = da @ db

        expected = a @ b
        actual = dc.full_tensor().reconcile()

    return actual, expected


模拟分布式线性层

def dtensor_linear_layer(world_size: int = 4):
    """Simulate a distributed linear layer forward pass.

    Returns: (actual, expected)
    """
    batch_size, in_features, out_features = 16, 8, 4

    with LocalTensorMode(world_size):
        mesh = init_device_mesh("cpu", (world_size,))

        x = torch.randn(batch_size, in_features)
        w = torch.randn(in_features, out_features)
        b = torch.randn(out_features)

        dx = distribute_tensor(x, mesh, [Shard(0)])
        dw = distribute_tensor(w, mesh, [Replicate()])
        db = distribute_tensor(b, mesh, [Replicate()])

        dy = torch.relu(dx @ dw + db)

        expected = torch.relu(x @ w + b)
        actual = dy.full_tensor().reconcile()

    return actual, expected


完整源代码: example_03_dtensor_integration.py

示例 4:处理不均匀分片#

现实世界的分布式系统通常在 rank 之间存在不均匀的数据分布。LocalTensor 使用 LocalIntNode 处理此问题。

创建每个 rank 具有不同大小的 LocalTensor

def create_uneven_shards():
    """Create LocalTensor with different sizes per rank.

    Returns: ((local_tensor, is_symint), expected_shapes_dict)
    """
    tensors = {
        0: torch.tensor([[1.0, 2.0, 3.0, 4.0]]),  # 1 row
        1: torch.tensor([[5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]),  # 2 rows
        2: torch.tensor([[13.0, 14.0, 15.0, 16.0]]),  # 1 row
    }

    lt = LocalTensor(tensors)
    is_symint = isinstance(lt.shape[0], torch.SymInt)

    expected_shapes = {rank: t.shape for rank, t in tensors.items()}
    return (lt, is_symint), expected_shapes


LocalIntNode 算术运算

def local_int_node_arithmetic():
    """LocalIntNode for per-rank integer values.

    Returns: ((add_result, mul_result), (expected_add, expected_mul))
    """
    values_a = {0: 10, 1: 20, 2: 30}
    values_b = {0: 1, 1: 2, 2: 3}
    local_a = LocalIntNode(values_a)
    local_b = LocalIntNode(values_b)

    result_add = local_a.add(local_b)
    result_mul = local_a.mul(local_b)

    expected_add = {k: values_a[k] + values_b[k] for k in values_a}
    expected_mul = {k: values_a[k] * values_b[k] for k in values_a}

    return (
        (dict(result_add._local_ints), dict(result_mul._local_ints)),
        (expected_add, expected_mul),
    )


维度不能均匀分割的 DTensor

def dtensor_uneven_sharding(world_size: int = 3):
    """DTensor with unevenly divisible tensor dimension.

    Returns: ((rows_per_rank, matches), expected_total_rows)
    """
    total_rows = 10

    with LocalTensorMode(world_size):
        mesh = init_device_mesh("cpu", (world_size,))

        tensor = torch.arange(total_rows * 4).reshape(total_rows, 4).float()
        dt = distribute_tensor(tensor, mesh, [Shard(0)])

        local = dt.to_local()
        rows_per_rank = {
            rank: local._local_tensors[rank].shape[0] for rank in range(world_size)
        }

        reconstructed = dt.full_tensor().reconcile()
        matches = torch.equal(reconstructed, tensor)

    return (rows_per_rank, matches), total_rows


完整源代码: example_04_uneven_sharding.py

示例 5:特定于 rank 的计算#

有时您需要在不同的 rank 上执行不同的操作。

使用 rank_map() 创建每个 rank 的值

def use_rank_map(world_size: int = 4):
    """Create LocalTensors with per-rank values using rank_map.

    Returns: (values_dict, expected_dict)
    """
    with LocalTensorMode(world_size) as mode:
        lt = mode.rank_map(lambda rank: torch.full((2, 3), float(rank)))
        values = {
            rank: lt._local_tensors[rank][0, 0].item() for rank in range(world_size)
        }

    expected = {rank: float(rank) for rank in range(world_size)}
    return values, expected


使用 tensor_map() 转换每个 rank 的分片

def use_tensor_map(world_size: int = 4):
    """Transform each shard differently using tensor_map.

    Returns: (values_dict, expected_dict)
    """
    with LocalTensorMode(world_size) as mode:
        lt = mode.rank_map(lambda rank: torch.ones(2, 2) * (rank + 1))

        def scale_by_rank(rank: int, tensor: torch.Tensor) -> torch.Tensor:
            return tensor * (rank + 1)

        scaled = mode.tensor_map(lt, scale_by_rank)
        values = {
            rank: scaled._local_tensors[rank][0, 0].item() for rank in range(world_size)
        }

    # (rank + 1) * (rank + 1) = (rank + 1)^2
    expected = {rank: float((rank + 1) ** 2) for rank in range(world_size)}
    return values, expected


暂时退出 LocalTensorMode

def disable_mode_temporarily(world_size: int = 4):
    """Temporarily exit LocalTensorMode for regular tensor ops.

    Returns: ((inside_type, disabled_type), (expected_inside, expected_disabled))
    """
    with LocalTensorMode(world_size) as mode:
        lt = torch.ones(2, 2)
        inside_type = type(lt).__name__

        with mode.disable():
            regular = torch.ones(2, 2)
            disabled_type = type(regular).__name__

    return (inside_type, disabled_type), ("LocalTensor", "Tensor")


完整源代码: example_05_rank_specific.py

示例 6:多维网格#

使用 2D/3D 设备网格进行混合并行性(例如,数据并行 + 张量并行)。

创建 2D 网格

def create_2d_mesh():
    """Create a 2D mesh for hybrid parallelism.

    Returns: ((shape, dim_names, total_size), (expected_shape, expected_names, expected_size))
    """
    world_size = 8
    dp_size, tp_size = 4, 2

    with LocalTensorMode(world_size):
        mesh = init_device_mesh("cpu", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
        shape = mesh.shape
        dim_names = mesh.mesh_dim_names
        total_size = mesh.size()

    expected = ((dp_size, tp_size), ("dp", "tp"), world_size)
    return (shape, dim_names, total_size), expected


混合并行性(DP + TP)

def hybrid_parallelism():
    """Combine data parallel and tensor parallel.

    Returns: (actual, expected)
    """
    world_size = 8
    dp_size, tp_size = 4, 2

    with LocalTensorMode(world_size):
        mesh = init_device_mesh("cpu", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

        x = torch.randn(16, 8)
        dx = distribute_tensor(x, mesh, [Shard(0), Replicate()])

        w = torch.randn(8, 12)
        dw = distribute_tensor(w, mesh, [Replicate(), Shard(1)])

        dy = dx @ dw

        expected = x @ w
        actual = dy.full_tensor().reconcile()

    return actual, expected


用于 DP + TP + PP 的 3D 网格

def create_3d_mesh():
    """Create a 3D mesh for DP + TP + PP.

    Returns: (actual, expected)
    """
    world_size = 24
    pp_size, dp_size, tp_size = 2, 3, 4

    with LocalTensorMode(world_size):
        mesh = init_device_mesh(
            "cpu",
            (pp_size, dp_size, tp_size),
            mesh_dim_names=("pp", "dp", "tp"),
        )

        tensor = torch.randn(8, 16, 32)
        dt = distribute_tensor(tensor, mesh, [Replicate(), Shard(0), Shard(2)])

        actual = dt.full_tensor().reconcile()

    return actual, tensor


完整源代码: example_06_multidim_mesh.py

测试教程示例#

本教程中的所有示例均经过测试以确保正确性。测试套件直接调用上述相同的函数

# From test_local_tensor_tutorial_examples.py
from example_01_basic_operations import create_local_tensor

def test_create_local_tensor(self):
    lt = create_local_tensor()
    self.assertIsInstance(lt, LocalTensor)
    self.assertEqual(lt.shape, torch.Size([2, 2]))

测试套件: test_local_tensor_tutorial_examples.py

API 参考#

核心类#

class torch.distributed._local_tensor.LocalTensor(local_tensors, requires_grad=False)[源代码]#

LocalTensor 是一个 Tensor 子类,它模拟了跨多个 SPMD(单程序多数据)rank 分布的张量。每个 LocalTensor 实例在内部保存了一个从全局 rank ID 到其对应本地 Tensor 分片的映射。对 LocalTensor 执行的操作独立地应用于每个本地分片,从而模拟分布式计算。集体操作和其他分布式操作通过将它们映射到本地分片来处理。

注意

此类主要用于调试和模拟单进程上的分布式张量计算。

返回类型:

LocalTensor

reconcile()[源代码]#

通过确保所有本地分片相同并返回其中一个分片的 detached 克隆来协调 LocalTensor 为单个 torch.Tensor。

注意

当预期所有分片都相同时,此方法对于从 LocalTensor 提取代表性张量很有用,例如在同步所有 rank 的集体操作之后。

返回类型:

张量

tolist()[源代码]#

尝试协调,如果成功则转换为列表,否则如果 dtype 是整数,则转换为本地整数列表。

返回类型:

list[Any]

class torch.distributed._local_tensor.LocalTensorMode(ranks)[源代码]#

一种 TorchDispatchMode,它模拟了 LocalTensor 对象在一组 rank 上的 SPMD(单程序多数据)执行。

LocalTensorMode 能够使 PyTorch 操作透明地应用于 LocalTensor 的每个本地分片,就像它们在多个 rank 上分布式一样。激活后,此模式会拦截张量操作并将其分派到每个 rank 的本地张量,收集并包装结果为 LocalTensor。它还通过将它们映射到本地实现来处理集体操作。

此模式主要用于调试和模拟单进程上的分布式张量计算,而不是用于高性能分布式训练。它维护一个活动模式堆栈,修补 DeviceMesh 坐标解析,并提供用于临时禁用模式或映射 rank 上函数的实用程序。

disable()[源代码]#

暂时禁用 LocalTensorMode。主要用于执行特定于 rank 的计算并在重新启用 LocalTensorMode 之前合并结果。

返回类型:

Generator[None, None, None]

rank_map(cb)[源代码]#

通过将 rank ID 映射到 ID 本地分片来创建 LocalTensor 实例。

返回类型:

LocalTensor

tensor_map(tensor, cb)[source]#

通过将 rank ID 映射到 ID 本地分片来创建 LocalTensor 实例。

返回类型:

LocalTensor

class torch.distributed._local_tensor.LocalIntNode(local_ints)[source]#

类似于 LocalTensor,但用于 int 类型。我们不能使用 0D 张量来表示它,因为通常只有 SymInt 才能在希望使用它的地方被接受。

返回类型:

ConstantIntNode | LocalIntNode

实用函数#

torch.distributed._local_tensor.local_tensor_mode()[source]#

返回当前活动的 LocalTensorMode(如果存在)。

此函数检查 LocalTensorMode 实例的全局堆栈。如果至少有一个 LocalTensorMode 处于活动状态,它将返回最近进入的(堆栈顶部)LocalTensorMode。如果没有 LocalTensorMode 处于活动状态,它将返回 None。

返回:

当前活动的 LocalTensorMode(如果处于活动状态),否则为 None。

返回类型:

Optional[LocalTensorMode]

torch.distributed._local_tensor.enabled_local_tensor_mode()[source]#

仅当其已启用时,才返回当前活动的 LocalTensorMode。

这是一个方便函数,它结合了检查 local_tensor_mode() 是否不为 None 且未禁用的常见模式。

返回:

当前活动的 LocalTensorMode(如果处于活动状态且已启用),否则为 None。

返回类型:

Optional[LocalTensorMode]

torch.distributed._local_tensor.maybe_run_for_local_tensor(func)[source]#

装饰器,可确保在 LocalTensorMode 下运行时,函数针对每个本地张量分片执行。如果不在 LocalTensorMode 中,则正常执行该函数。在 LocalTensorMode 中,该函数针对每个 rank 运行,并且结果会适当地收集。

此装饰器对于需要 rank 特定操作(例如,根据 rank 计算输入张量偏移量)的函数非常有用。

请注意,被装饰的函数不应有任何副作用,并且仅包含单个 rank 的操作。例如,包装执行集体操作的函数将不起作用。

参数:

func (Callable[..., Any]) – 要装饰的函数。

返回:

处理 LocalTensorMode 逻辑的包装函数。

返回类型:

Callable[…, Any]

torch.distributed._local_tensor.maybe_disable_local_tensor_mode()[source]#

上下文管理器,在上下文持续期间禁用 LocalTensorMode。

返回类型:

AbstractContextManager

最佳实践#

  1. 仅用于测试:LocalTensor 具有显著的开销,不应在生产代码中使用。

  2. 初始化进程组:即使是本地测试,也需要初始化进程组(使用“fake”后端)。

  3. 避免在内部张量上使用 requires_grad:LocalTensor 期望内部张量不具有 requires_grad=True。在 LocalTensor 包装器上设置梯度。

  4. 为了断言进行协调:使用 reconcile() 在所有 rank 应该具有相同值时提取单个张量(例如,在 all-reduce 之后)。

  5. 使用直接访问进行调试:通过 tensor._local_tensors[rank] 访问各个分片以进行调试。

常见陷阱#

  1. 忘记上下文管理器:在 LocalTensorMode 之外对 LocalTensor 进行的操作仍然有效,但不会从工厂创建新的 LocalTensor。

  2. 不匹配的 rank:确保操作中的所有 LocalTensor 具有兼容的 rank。

  3. 内部张量梯度:从具有 requires_grad=True 的张量创建 LocalTensor 将引发错误。