LocalTensor Tutorial: Single-Process SPMD Debugging#
Created: January 7, 2026 | Last updated: January 7, 2026
This tutorial introduces LocalTensor, a powerful debugging tool for developing and testing distributed tensor operations without requiring multiple processes or GPUs.
What is LocalTensor?#
LocalTensor is a torch.Tensor subclass that simulates distributed SPMD (Single Program, Multiple Data) computation on a single process. Internally it maintains a mapping from rank IDs to the corresponding local tensor shards, letting you debug and test distributed code without the infrastructure overhead.
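As a minimal sketch of that mapping (using the import path from the API reference below), a LocalTensor can be constructed directly from a dict of per-rank shards and each rank's shard inspected in place:
import torch
from torch.distributed._local_tensor import LocalTensor

# One plain tensor per rank; the LocalTensor wraps the whole rank -> shard mapping.
lt = LocalTensor({0: torch.zeros(2), 1: torch.ones(2)})
print(lt._local_tensors[1])  # look at rank 1's shard directly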
Key Benefits#
No multi-process setup: test distributed algorithms on a single CPU/GPU
Faster debugging cycles: iterate quickly without launching multiple processes
Full visibility: inspect each rank's tensor state directly
CI friendly: run distributed tests in single-process CI pipelines
DTensor integration: test DTensor code locally with no changes
Note
LocalTensor is intended for debugging and testing only, not for production use. Simulating multiple ranks locally has significant overhead.
Installation and Setup#
LocalTensor is part of the PyTorch distributed package. No installation beyond PyTorch itself is required.
Usage Examples#
The following examples demonstrate the core patterns for using LocalTensor. The code for each example comes directly from source files that are also tested for correctness; the tests call these same functions directly.
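The snippets below omit their import lines; a typical header for running them locally, assuming the module paths shown in the API reference below plus the public DTensor entry points, looks like this:
import torch
import torch.distributed as dist
from torch.distributed._local_tensor import (
    LocalIntNode,
    LocalTensor,
    LocalTensorMode,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor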
Example 1: Basic LocalTensor Creation and Operations#
Creating a LocalTensor from per-rank tensors
def create_local_tensor():
"""Create a LocalTensor from per-rank tensors.
Returns: (local_tensor, (expected_shape, expected_ranks, expected_rank_0, expected_rank_1))
"""
rank_0_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
rank_1_tensor = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
local_tensor = LocalTensor({0: rank_0_tensor, 1: rank_1_tensor})
expected = (torch.Size([2, 2]), frozenset({0, 1}), rank_0_tensor, rank_1_tensor)
return local_tensor, expected
Arithmetic operations (applied to each rank)
def arithmetic_operations():
"""Demonstrate arithmetic on LocalTensor.
Returns: ((doubled, added), (expected_doubled_0, expected_doubled_1, expected_added_0))
"""
input_0 = torch.tensor([1.0, 2.0, 3.0])
input_1 = torch.tensor([4.0, 5.0, 6.0])
lt = LocalTensor({0: input_0, 1: input_1})
doubled = lt * 2
added = lt + 10
expected = (input_0 * 2, input_1 * 2, input_0 + 10)
return (doubled, added), expected
Extracting a single tensor when all shards are identical
def reconcile_identical_shards():
"""Extract a single tensor when all shards are identical.
Returns: (result, expected)
"""
value = torch.tensor([1.0, 2.0, 3.0])
lt = LocalTensor({0: value.clone(), 1: value.clone(), 2: value.clone()})
result = lt.reconcile()
return result, value
Automatic LocalTensor creation with LocalTensorMode
def use_local_tensor_mode(world_size: int = 4):
"""Use LocalTensorMode to auto-create LocalTensors.
Returns: ((is_local, num_ranks), (expected_is_local, expected_num_ranks))
"""
with LocalTensorMode(world_size):
x = torch.ones(2, 3)
is_local = isinstance(x, LocalTensor)
num_ranks = len(x._ranks)
return (is_local, num_ranks), (True, world_size)
Example 2: Simulating Collective Operations#
Test collective operations such as all_reduce, broadcast, and all_gather without multiple processes.
All-reduce with SUM
def all_reduce_sum(process_group):
"""Simulate all_reduce with SUM across ranks.
Returns: (result, expected)
"""
tensors = {
0: torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
1: torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
2: torch.tensor([[9.0, 10.0], [11.0, 12.0]]),
}
expected = sum(tensors.values())
with LocalTensorMode(frozenset(tensors.keys())):
lt = LocalTensor({k: v.clone() for k, v in tensors.items()})
dist.all_reduce(lt, op=dist.ReduceOp.SUM, group=process_group)
result = lt.reconcile()
return result, expected
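For these inputs, every rank ends up holding the elementwise sum of the three shards, so the reconciled result is [[15., 18.], [21., 24.]].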
Broadcast from a source rank
def broadcast_from_rank(process_group, src_rank: int = 0):
"""Simulate broadcast from a source rank.
Returns: (result, expected)
"""
tensors = {
0: torch.tensor([10.0, 20.0, 30.0]),
1: torch.tensor([40.0, 50.0, 60.0]),
2: torch.tensor([70.0, 80.0, 90.0]),
}
expected = tensors[src_rank].clone()
with LocalTensorMode(frozenset(tensors.keys())):
lt = LocalTensor({k: v.clone() for k, v in tensors.items()})
dist.broadcast(lt, src=src_rank, group=process_group)
result = lt.reconcile()
return result, expected
All-gather to collect tensors from all ranks
def all_gather_tensors(process_group):
"""Simulate all_gather to collect tensors from all ranks.
Returns: (results_list, expected_list)
"""
tensors = {
0: torch.tensor([[1.0, 2.0]]),
1: torch.tensor([[3.0, 4.0]]),
2: torch.tensor([[5.0, 6.0]]),
}
num_ranks = len(tensors)
expected = [tensors[i].clone() for i in range(num_ranks)]
with LocalTensorMode(frozenset(tensors.keys())):
lt = LocalTensor(tensors)
output_list = [torch.zeros_like(lt) for _ in range(num_ranks)]
dist.all_gather(output_list, lt, group=process_group)
results = [out.reconcile() for out in output_list]
return results, expected
Example 3: Working with DTensor#
LocalTensor integrates with DTensor for testing distributed tensor parallelism.
Distributing a tensor and verifying reconstruction
def distribute_and_verify(world_size: int = 4):
"""Distribute a tensor and verify reconstruction.
Returns: ((sharded_actual, replicated_actual), (sharded_expected, replicated_expected))
"""
with LocalTensorMode(world_size):
mesh = init_device_mesh("cpu", (world_size,))
tensor = torch.arange(16).reshape(4, 4).float()
dt_sharded = distribute_tensor(tensor, mesh, [Shard(0)])
dt_replicated = distribute_tensor(tensor, mesh, [Replicate()])
sharded_actual = dt_sharded.full_tensor().reconcile()
replicated_actual = dt_replicated.to_local().reconcile()
return (sharded_actual, replicated_actual), (tensor, tensor)
Distributed matrix multiplication
def dtensor_matmul(world_size: int = 4):
"""Perform matrix multiplication with DTensors.
Returns: (actual, expected)
"""
with LocalTensorMode(world_size):
mesh = init_device_mesh("cpu", (world_size,))
a = torch.randn(8, 4)
b = torch.randn(4, 6)
da = distribute_tensor(a, mesh, [Shard(0)])
db = distribute_tensor(b, mesh, [Replicate()])
dc = da @ db
expected = a @ b
actual = dc.full_tensor().reconcile()
return actual, expected
Simulating a distributed linear layer
def dtensor_linear_layer(world_size: int = 4):
"""Simulate a distributed linear layer forward pass.
Returns: (actual, expected)
"""
batch_size, in_features, out_features = 16, 8, 4
with LocalTensorMode(world_size):
mesh = init_device_mesh("cpu", (world_size,))
x = torch.randn(batch_size, in_features)
w = torch.randn(in_features, out_features)
b = torch.randn(out_features)
dx = distribute_tensor(x, mesh, [Shard(0)])
dw = distribute_tensor(w, mesh, [Replicate()])
db = distribute_tensor(b, mesh, [Replicate()])
dy = torch.relu(dx @ dw + db)
expected = torch.relu(x @ w + b)
actual = dy.full_tensor().reconcile()
return actual, expected
Example 4: Handling Uneven Sharding#
Real-world distributed systems often have uneven data distribution across ranks. LocalTensor handles this with LocalIntNode.
Creating a LocalTensor with different sizes per rank
def create_uneven_shards():
"""Create LocalTensor with different sizes per rank.
Returns: ((local_tensor, is_symint), expected_shapes_dict)
"""
tensors = {
0: torch.tensor([[1.0, 2.0, 3.0, 4.0]]), # 1 row
1: torch.tensor([[5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]), # 2 rows
2: torch.tensor([[13.0, 14.0, 15.0, 16.0]]), # 1 row
}
lt = LocalTensor(tensors)
is_symint = isinstance(lt.shape[0], torch.SymInt)
expected_shapes = {rank: t.shape for rank, t in tensors.items()}
return (lt, is_symint), expected_shapes
LocalIntNode arithmetic
def local_int_node_arithmetic():
"""LocalIntNode for per-rank integer values.
Returns: ((add_result, mul_result), (expected_add, expected_mul))
"""
values_a = {0: 10, 1: 20, 2: 30}
values_b = {0: 1, 1: 2, 2: 3}
local_a = LocalIntNode(values_a)
local_b = LocalIntNode(values_b)
result_add = local_a.add(local_b)
result_mul = local_a.mul(local_b)
expected_add = {k: values_a[k] + values_b[k] for k in values_a}
expected_mul = {k: values_a[k] * values_b[k] for k in values_a}
return (
(dict(result_add._local_ints), dict(result_mul._local_ints)),
(expected_add, expected_mul),
)
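For these inputs, result_add holds {0: 11, 1: 22, 2: 33} and result_mul holds {0: 10, 1: 40, 2: 90}, matching the per-rank expectations computed above.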
DTensor with a dimension that does not divide evenly
def dtensor_uneven_sharding(world_size: int = 3):
"""DTensor with unevenly divisible tensor dimension.
Returns: ((rows_per_rank, matches), expected_total_rows)
"""
total_rows = 10
with LocalTensorMode(world_size):
mesh = init_device_mesh("cpu", (world_size,))
tensor = torch.arange(total_rows * 4).reshape(total_rows, 4).float()
dt = distribute_tensor(tensor, mesh, [Shard(0)])
local = dt.to_local()
rows_per_rank = {
rank: local._local_tensors[rank].shape[0] for rank in range(world_size)
}
reconstructed = dt.full_tensor().reconcile()
matches = torch.equal(reconstructed, tensor)
return (rows_per_rank, matches), total_rows
Example 5: Rank-Specific Computation#
Sometimes you need to perform different operations on different ranks.
Creating per-rank values with rank_map()
def use_rank_map(world_size: int = 4):
"""Create LocalTensors with per-rank values using rank_map.
Returns: (values_dict, expected_dict)
"""
with LocalTensorMode(world_size) as mode:
lt = mode.rank_map(lambda rank: torch.full((2, 3), float(rank)))
values = {
rank: lt._local_tensors[rank][0, 0].item() for rank in range(world_size)
}
expected = {rank: float(rank) for rank in range(world_size)}
return values, expected
Transforming each rank's shard with tensor_map()
def use_tensor_map(world_size: int = 4):
"""Transform each shard differently using tensor_map.
Returns: (values_dict, expected_dict)
"""
with LocalTensorMode(world_size) as mode:
lt = mode.rank_map(lambda rank: torch.ones(2, 2) * (rank + 1))
def scale_by_rank(rank: int, tensor: torch.Tensor) -> torch.Tensor:
return tensor * (rank + 1)
scaled = mode.tensor_map(lt, scale_by_rank)
values = {
rank: scaled._local_tensors[rank][0, 0].item() for rank in range(world_size)
}
# (rank + 1) * (rank + 1) = (rank + 1)^2
expected = {rank: float((rank + 1) ** 2) for rank in range(world_size)}
return values, expected
Temporarily exiting LocalTensorMode
def disable_mode_temporarily(world_size: int = 4):
"""Temporarily exit LocalTensorMode for regular tensor ops.
Returns: ((inside_type, disabled_type), (expected_inside, expected_disabled))
"""
with LocalTensorMode(world_size) as mode:
lt = torch.ones(2, 2)
inside_type = type(lt).__name__
with mode.disable():
regular = torch.ones(2, 2)
disabled_type = type(regular).__name__
return (inside_type, disabled_type), ("LocalTensor", "Tensor")
Full source code: example_05_rank_specific.py
Example 6: Multi-Dimensional Meshes#
Use 2D/3D device meshes for hybrid parallelism (for example, data parallel + tensor parallel).
Creating a 2D mesh
def create_2d_mesh():
"""Create a 2D mesh for hybrid parallelism.
Returns: ((shape, dim_names, total_size), (expected_shape, expected_names, expected_size))
"""
world_size = 8
dp_size, tp_size = 4, 2
with LocalTensorMode(world_size):
mesh = init_device_mesh("cpu", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
shape = mesh.shape
dim_names = mesh.mesh_dim_names
total_size = mesh.size()
expected = ((dp_size, tp_size), ("dp", "tp"), world_size)
return (shape, dim_names, total_size), expected
Hybrid parallelism (DP + TP)
def hybrid_parallelism():
"""Combine data parallel and tensor parallel.
Returns: (actual, expected)
"""
world_size = 8
dp_size, tp_size = 4, 2
with LocalTensorMode(world_size):
mesh = init_device_mesh("cpu", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
x = torch.randn(16, 8)
dx = distribute_tensor(x, mesh, [Shard(0), Replicate()])
w = torch.randn(8, 12)
dw = distribute_tensor(w, mesh, [Replicate(), Shard(1)])
dy = dx @ dw
expected = x @ w
actual = dy.full_tensor().reconcile()
return actual, expected
3D mesh for DP + TP + PP
def create_3d_mesh():
"""Create a 3D mesh for DP + TP + PP.
Returns: (actual, expected)
"""
world_size = 24
pp_size, dp_size, tp_size = 2, 3, 4
with LocalTensorMode(world_size):
mesh = init_device_mesh(
"cpu",
(pp_size, dp_size, tp_size),
mesh_dim_names=("pp", "dp", "tp"),
)
tensor = torch.randn(8, 16, 32)
dt = distribute_tensor(tensor, mesh, [Replicate(), Shard(0), Shard(2)])
actual = dt.full_tensor().reconcile()
return actual, tensor
Full source code: example_06_multidim_mesh.py
Testing the Tutorial Examples#
All examples in this tutorial are tested for correctness. The test suite calls the same functions shown above:
# From test_local_tensor_tutorial_examples.py
from example_01_basic_operations import create_local_tensor
def test_create_local_tensor(self):
    lt, _expected = create_local_tensor()  # the example returns (local_tensor, expected)
    self.assertIsInstance(lt, LocalTensor)
    self.assertEqual(lt.shape, torch.Size([2, 2]))
API Reference#
Core Classes#
- class torch.distributed._local_tensor.LocalTensor(local_tensors, requires_grad=False)[source]#
LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD (Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from global rank IDs to the corresponding local tensor shards. Operations performed on a LocalTensor are applied to each local shard independently, simulating distributed computation. Collectives and other distributed operations are handled by mapping them onto the local shards.
Note
This class is primarily intended for debugging and for simulating distributed tensor computation on a single process.
- class torch.distributed._local_tensor.LocalTensorMode(ranks)[source]#
A TorchDispatchMode that simulates SPMD (Single Program, Multiple Data) execution of LocalTensor objects over a set of ranks.
LocalTensorMode lets PyTorch operations be applied transparently to each local shard of a LocalTensor, as if they were running distributed across multiple ranks. When active, the mode intercepts tensor operations, dispatches them to each rank's local tensor, and collects and wraps the results as a LocalTensor. It also handles collective operations by mapping them to local implementations.
This mode is primarily intended for debugging and for simulating distributed tensor computation on a single process, not for high-performance distributed training. It maintains a stack of active modes, patches DeviceMesh coordinate resolution, and provides utilities for temporarily disabling the mode or for mapping functions over ranks.
- disable()[source]#
Temporarily disable LocalTensorMode. This is mainly used to perform rank-specific computation and merge the results before re-enabling LocalTensorMode.
- Return type:
Generator[None, None, None]
- class torch.distributed._local_tensor.LocalIntNode(local_ints)[source]#
Similar to LocalTensor, but for ints. We cannot represent it with a 0-d tensor because, in general, only a SymInt is accepted in the places where we want to use it.
- Return type:
ConstantIntNode | LocalIntNode
Utility Functions#
- torch.distributed._local_tensor.local_tensor_mode()[source]#
Returns the currently active LocalTensorMode, if any.
This function inspects the global stack of LocalTensorMode instances. If at least one LocalTensorMode is active, it returns the most recently entered one (the top of the stack). If no LocalTensorMode is active, it returns None.
- Returns:
The currently active LocalTensorMode if one is active, otherwise None.
- Return type:
Optional[LocalTensorMode]
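A quick sketch of that stack behavior (assuming LocalTensorMode accepts an integer world size, as in the examples above):
from torch.distributed._local_tensor import LocalTensorMode, local_tensor_mode

print(local_tensor_mode())        # None: no mode is active yet
with LocalTensorMode(4):
    print(local_tensor_mode())    # the most recently entered (innermost) mode
print(local_tensor_mode())        # None again once the mode leaves the stack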
- torch.distributed._local_tensor.enabled_local_tensor_mode()[source]#
Returns the currently active LocalTensorMode only if it is enabled.
This is a convenience function that combines the common pattern of checking that local_tensor_mode() is not None and that the mode is not disabled.
- Returns:
The currently active LocalTensorMode if one is active and enabled, otherwise None.
- Return type:
Optional[LocalTensorMode]
- torch.distributed._local_tensor.maybe_run_for_local_tensor(func)[source]#
Decorator that ensures a function runs once per local tensor shard when executing under LocalTensorMode. Outside of LocalTensorMode, the function executes normally. Under LocalTensorMode, the function is run for each rank and the results are collected appropriately.
This decorator is useful for functions that need rank-specific behavior (for example, computing an input tensor offset based on the rank).
Note that the decorated function must not have side effects and must contain only single-rank operations. For example, wrapping a function that performs a collective will not work.
- Parameters:
func (Callable[..., Any]) – The function to decorate.
- Returns:
The wrapped function that handles the LocalTensorMode logic.
- Return type:
Callable[..., Any]
Best Practices#
Use it for testing only: LocalTensor has significant overhead and should not be used in production code.
Initialize a process group: even local tests need an initialized process group (use the "fake" backend); see the sketch after this list.
Avoid requires_grad on inner tensors: LocalTensor expects the inner tensors to not have requires_grad=True. Set gradients on the LocalTensor wrapper instead.
Reconcile for assertions: use reconcile() to extract a single tensor when all ranks should hold the same value (for example, after an all-reduce).
Use direct access for debugging: inspect individual shards via tensor._local_tensors[rank] when debugging.
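A minimal sketch of that setup, assuming the "fake" backend and FakeStore from torch.testing._internal.distributed.fake_pg are available in your build (they are private testing utilities), reusing the all_reduce_sum function from Example 2:
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# Importing fake_pg registers the "fake" backend; a single process stands in
# for all ranks, so rank=0 with world_size=3 matches the example's three shards.
dist.init_process_group("fake", rank=0, world_size=3, store=FakeStore())
try:
    result, expected = all_reduce_sum(dist.group.WORLD)
    assert torch.equal(result, expected)
finally:
    dist.destroy_process_group()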
Common Pitfalls#
Forgetting the context manager: operating on an existing LocalTensor outside of LocalTensorMode still works, but tensor factories will not create new LocalTensors.
Mismatched ranks: make sure all LocalTensors involved in an operation have compatible ranks.
Gradients on inner tensors: creating a LocalTensor from tensors with requires_grad=True raises an error.