PyTorch/XLA API¶
torch_xla¶
- torch_xla.device(index: int = None) device [source]¶
返回一个 XLA 设备实例。
如果启用了 SPMD,则返回一个封装此进程可用的所有设备的虚拟设备。
- 参数:
index – 要返回的 XLA 设备的索引。对应于 torch_xla.devices() 中的索引。默认情况下,获取第一个设备。
- 返回:
一个 XLA torch.device。
- torch_xla.sync(wait: bool = False, reset_scope: bool = True)[source]¶
启动所有待定的图操作。
- 参数:
wait (bool) – 是否阻塞当前进程直到执行完成。
reset_scope (bool) – 是否重置 IR 节点的 torch::lazy::ScopeContext。
- torch_xla.compile(f: Optional[Callable] = None, full_graph: Optional[bool] = False, name: Optional[str] = None, max_different_graphs: Optional[int] = None)[source]¶
使用 torch_xla 的 LazyTensor 跟踪模式优化给定的模型/函数。PyTorch/XLA 将使用给定的输入跟踪该函数,然后生成图以表示函数内发生的 PyTorch 操作。此图将由 XLA 编译并在加速器(由张量的设备决定)上执行。对于函数编译区域,将禁用即时模式。
- 参数:
model (Callable) – 要优化的模块/函数,如果未提供,此函数将作为上下文管理器。
full_graph (Optional[bool]) – 此编译是否应生成单个图。如果设置为 True 且生成多个图,torch_xla 将抛出带有调试信息的错误并退出。
name (Optional[name]) – 编译程序的名称。如果未指定,将使用函数 f 的名称。此名称将用于 PT_XLA_DEBUG 消息以及 HLO/IR 转储文件。
max_different_graphs (Optional[python:int]) – 允许给定的模型/函数具有的不同跟踪图的数量。如果超出此限制,将引发错误。
示例
# usage 1 @torch_xla.compile() def foo(x): return torch.sin(x) + torch.cos(x) def foo2(x): return torch.sin(x) + torch.cos(x) # usage 2 compiled_foo2 = torch_xla.compile(foo2) # usage 3 with torch_xla.compile(): res = foo2(x)
backends¶
torch_xla.backends 控制 XLA 后端的行为。
此子包与 PyTorch 中的 torch.backends.{cuda, cpu, mps, etc} 子包并行。
- torch_xla.backends.set_mat_mul_precision(precision: Literal['default', 'high', 'highest']) None [source]¶
控制 32 位输入的默认矩阵乘法和卷积精度。
某些平台(如 TPU)提供可配置的矩阵乘法和卷积计算精度级别,以牺牲精度换取速度。
此选项控制 32 位输入上的矩阵乘法和卷积计算的默认精度级别。级别描述了标量乘积的计算精度。
- 在 TPU 上
default 是最快、最不精确的,它在乘法之前将 FP32 降级为 BF16。
high 需要三次传递,产生大约 14 位精度。
highest 是最精确但最慢的。它需要六次传递,产生大约 22 位精度。
有关精度级别的更多信息,请参阅 [精度教程](../../tutorials/precision_tutorial.html)。
- 注意:不建议多次设置矩阵乘法精度。
如果需要这样做,请通过实验验证精度设置是否按预期工作。
- 参数:
precision (str) – 要为矩阵乘法设置的精度。必须是 ‘default’、‘high’ 或 ‘highest’ 之一。
runtime¶
- torch_xla.runtime.device_type() Optional[str] [source]¶
返回当前的 PjRt 设备类型。
如果尚未配置默认设备,则选择一个默认设备
- 返回:
设备的字符串表示。
- torch_xla.runtime.global_ordinal() int [source]¶
返回此线程在所有进程中的全局序数。
全局序数在 [0, global_device_count) 范围内。全局序数与 TPU 工作器 ID 之间没有保证可预测的关系,并且不能保证在每个主机上都是连续的。
- torch_xla.runtime.get_master_ip() str [source]¶
检索运行时的 master worker IP。此调用将进入特定于后端的发现 API。
- 返回:
master worker 的 IP 地址(字符串形式)。
- torch_xla.runtime.use_spmd(auto: Optional[bool] = False)[source]¶
启用 SPMD 模式的 API。这是启用 SPMD 的推荐方式。
如果某些张量已在非 SPMD 设备上初始化,这将强制 SPMD 模式。这意味着这些张量将在设备之间复制。
- 参数:
auto (bool) – 是否启用自动分片。有关更多详细信息,请阅读 https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding
xla_model¶
- torch_xla.core.xla_model.xla_device(n: Optional[int] = None, devkind: Optional[str] = None) device [source]¶
返回一个 XLA 设备实例。
- 参数:
n (python:int, optional) – 要返回的特定实例(序数)。如果指定,将返回特定的 XLA 设备实例。否则,将返回第一个设备(默认为 0)。
devkind (string..., optional) – 如果指定,则为设备类型,例如 TPU、CUDA、CPU 或自定义 PJRT 设备。已弃用。
- 返回:
具有请求的 XLA 设备实例的 torch.device。
- torch_xla.core.xla_model.xla_device_hw(device: Union[str, device]) str [source]¶
返回给定设备的硬件类型。
- 参数:
device (string or torch.device) – 将映射到真实设备的 xla 设备。
- 返回:
给定设备硬件类型的字符串表示。
- torch_xla.core.xla_model.is_master_ordinal(local: bool = True) bool [source]¶
检查当前进程是否为主序数(0)。
- 参数:
local (bool) – 应检查本地还是全局主序数。在多主机复制的情况下,只有一个全局主序数(主机 0,设备 0),而有 NUM_HOSTS 个本地主序数。默认值:True
- 返回:
一个布尔值,指示当前进程是否为主序数。
- torch_xla.core.xla_model.all_reduce(reduce_type: str, inputs: Union[Tensor, List[Tensor]], scale: float = 1.0, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Union[Tensor, List[Tensor]] [source]¶
对输入张量执行原地归约操作。
- 参数:
reduce_type (string) –
xm.REDUCE_SUM
、xm.REDUCE_MUL
、xm.REDUCE_AND
、xm.REDUCE_OR
、xm.REDUCE_MIN
和xm.REDUCE_MAX
之一。inputs – 要进行 all reduce 操作的单个 torch.Tensor 或张量列表。
scale (python:float) – 归约后应用的默认缩放值。默认为:1.0
groups (list, optional) – 列表的列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]] 定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个包含所有副本的组。
pin_layout (bool, optional) – 是否为通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时潜在的数据损坏,但这可能会导致某些 xla 编译失败。
- 返回:
如果传递单个 torch.Tensor,则返回值是包含归约值(跨副本)的 torch.Tensor。如果传递列表/元组,此函数将对输入张量执行原地 all-reduce 操作,并返回列表/元组本身。
- torch_xla.core.xla_model.all_gather(value: Tensor, dim: int = 0, groups: Optional[List[List[int]]] = None, output: Optional[Tensor] = None, pin_layout: bool = True, channel_id=None, use_global_device_ids=None) Tensor [source]¶
沿给定维度执行 all-gather 操作。
- 参数:
value (torch.Tensor) – 输入张量。
dim (python:int) – 收集维度。默认为 0
groups (list, optional) – 列表的列表,表示 all_gather() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]] 定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个包含所有副本的组。
output (torch.Tensor) – 可选输出张量。
pin_layout (bool, optional) – 是否为通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时潜在的数据损坏,但这可能会导致某些 xla 编译失败。
channel_id (python:int, optional) – 用于跨模块通信的可选通道 ID
use_global_device_ids (bool, optional) – 如果为 true,则将 id 解释为全局设备 id
- 返回:
在
dim
维度上包含所有参与副本值的张量。
- torch_xla.core.xla_model.all_to_all(value: Tensor, split_dimension: int, concat_dimension: int, split_count: int, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Tensor [source]¶
对输入张量执行 XLA AllToAll() 操作。
参见:https://tensorflowcn.cn/xla/operation_semantics#alltoall
- 参数:
value (torch.Tensor) – 输入张量。
split_dimension (python:int) – 应在其上进行分割的维度。
concat_dimension (python:int) – 应在其上进行连接的维度。
split_count (python:int) – 分割计数。
groups (list, optional) – 列表的列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]] 定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个包含所有副本的组。
pin_layout (bool, optional) – 是否为通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时潜在的数据损坏,但这可能会导致某些 xla 编译失败。
- 返回:
all_to_all() 操作的结果 torch.Tensor。
- torch_xla.core.xla_model.add_step_closure(closure: Callable[[...], Any], args: Tuple[Any, ...] = (), run_async: bool = False)[source]¶
将一个闭包添加到将在步结束时运行的闭包列表中。
在模型训练过程中,经常需要打印/报告(例如,打印到控制台、发布到 tensorboard 等)信息,这些信息需要检查中间张量的内容。在模型代码的不同点检查不同张量的内容需要多次执行,通常会导致性能问题。添加步闭包将确保它在屏障后运行,此时所有活动张量都将已物化到设备数据。活动张量将包括闭包参数捕获的张量。因此,使用 add_step_closure() 将确保执行一次,即使有多个闭包排队,需要检查多个张量。步闭包将按它们排队的顺序顺序运行。请注意,即使使用此 API,执行也会得到优化,但建议每 N 步限制一次打印/报告事件。
- 参数:
closure (callable) – 要调用的函数。
args (tuple) – 要传递给闭包的参数。
run_async – 如果为 True,则异步运行闭包。
- torch_xla.core.xla_model.wait_device_ops(devices: List[str] = [])[source]¶
等待给定设备上的所有异步操作完成。
- 参数:
devices (string..., optional) – 需要等待其异步操作的设备。如果为空,将等待所有本地设备。
- torch_xla.core.xla_model.optimizer_step(optimizer: Optimizer, barrier: bool = False, optimizer_args: Dict = {}, groups: Optional[List[List[int]]] = None, pin_layout: bool = True)[source]¶
运行提供的优化器步长并同步所有设备上的梯度。
- 参数:
optimizer (
torch.Optimizer
) – 需要调用其 step() 函数的 torch.Optimizer 实例。将使用 optimizer_args 命名参数调用 step() 函数。barrier (bool, optional) – 是否在此 API 中发出 XLA 张量屏障。如果使用 PyTorch XLA ParallelLoader 或 DataParallel 支持,则不需要此参数,因为屏障将由 XLA 数据加载器迭代器 next() 调用发出。默认值:False
optimizer_args (dict, optional) – optimizer.step() 调用的命名参数字典。
groups (list, optional) – 列表的列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]] 定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个包含所有副本的组。
pin_layout (bool, optional) – 减少梯度时是否固定布局。有关详细信息,请参阅 xm.all_reduce。
- 返回:
由 optimizer.step() 调用返回的相同值。
示例
>>> import torch_xla.core.xla_model as xm >>> xm.optimizer_step(self.optimizer)
- torch_xla.core.xla_model.save(data: Any, file_or_path: Union[str, TextIO], master_only: bool = True, global_master: bool = False)[source]¶
将输入数据保存到文件。
保存的数据将被传输到 PyTorch CPU 设备,然后保存,因此后续的 torch.load() 将加载 CPU 数据。在处理视图时必须小心。与其保存视图,不如建议在张量已加载并移至其目标设备后重新创建它们。
- 参数:
data – 要保存的输入数据。任何嵌套的 Python 对象组合(列表、元组、集合、字典等)。
file_or_path – 数据保存操作的目标。可以是文件路径,也可以是 Python 文件对象。如果 master_only 为
False
,则路径或文件对象必须指向不同的目标,否则同一主机上的所有写入都会相互覆盖。master_only (bool, optional) – 是否仅主设备保存数据。如果为 False,则 file_or_path 参数应为参与复制的每个序数指向不同的文件或路径,否则同一主机上的所有副本将写入同一位置。默认值:True
global_master (bool, optional) – 当
master_only
为True
时,此标志控制是每个主机的 master(如果global_master
为False
)保存内容,还是只有全局 master(序数 0)保存。默认值:False
示例
>>> import torch_xla.core.xla_model as xm >>> xm.wait_device_ops() # wait for all pending operations to finish. >>> xm.save(obj_to_save, path_to_save) >>> xm.rendezvous('torch_xla.core.xla_model.save') # multi process context only
- torch_xla.core.xla_model.rendezvous(tag: str, payload: bytes = b'', replicas: List[int] = []) List[bytes] [source]¶
等待所有 mesh 客户端到达命名 rendezvous。
注意:PJRT 不支持 XRT mesh 服务器,因此这实际上是 xla_rendezvous 的别名。
- 参数:
tag (string) – rendezvous 的名称。
payload (bytes, optional) – 要发送到 rendezvous 的有效负载。
replicas (list, python:int) – 参与 rendezvous 的副本序数。空表示 mesh 中的所有副本。默认值:[]
- 返回:
所有其他核心交换的有效负载,其中核心序数 i 的有效负载位于返回元组的第 i 个位置。
示例
>>> import torch_xla.core.xla_model as xm >>> xm.rendezvous('example')
- torch_xla.core.xla_model.mesh_reduce(tag: str, data, reduce_fn: Callable[[...], Any]) Union[Any, ToXlaTensorArena] [source]¶
执行图外客户端 mesh 归约。
- 参数:
tag (string) – rendezvous 的名称。
data – 要归约的数据。reduce_fn 可调用对象将接收一个列表,其中包含来自所有 mesh 客户端进程(每个核心一个)的相同数据的副本。
reduce_fn (callable) – 一个函数,它接收 data-like 对象的列表并返回归约结果。
- 返回:
归约后的值。
示例
>>> import torch_xla.core.xla_model as xm >>> import numpy as np >>> accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
- torch_xla.core.xla_model.set_rng_state(seed: int, device: Optional[str] = None)[source]¶
设置随机数生成器状态。
- 参数:
seed (python:integer) – 要设置的状态。
device (string, optional) – 需要设置 RNG 状态的设备。如果缺失,将设置默认设备种子。
- torch_xla.core.xla_model.get_rng_state(device: Optional[str] = None) int [source]¶
获取当前的随机数生成器状态。
- 参数:
device (string, optional) – 需要检索其 RNG 状态的设备。如果缺失,将设置默认设备种子。
- 返回:
RNG 状态(整数形式)。
- torch_xla.core.xla_model.get_memory_info(device: Optional[device] = None) MemoryInfo [source]¶
检索设备内存使用情况。
- 参数:
device – Optional[torch.device] 请求内存信息的设备。
device. (如果未传递,将使用默认值) –
- 返回:
包含给定设备内存使用情况的 MemoryInfo 字典。
示例
>>> xm.get_memory_info() {'bytes_used': 290816, 'bytes_limit': 34088157184, 'peak_bytes_used': 500816}
- torch_xla.core.xla_model.get_stablehlo(tensors: Optional[List[Tensor]] = None) str [source]¶
以字符串格式获取计算图的 StableHLO。
如果 tensors 非空,则将转储以 tensors 为输出的图。如果 tensors 为空,则将转储整个计算图。
对于推理图,建议将模型输出传递给 tensors。对于训练图,识别“输出”并不直接。建议使用空的 tensors。
要启用 StableHLO 中的源代码行信息,请设置环境变量 XLA_HLO_DEBUG=1。
- 参数:
tensors (list[torch.Tensor], optional) – 代表 StableHLO 图的输出/根的张量。
- 返回:
StableHLO 模块(字符串格式)。
- torch_xla.core.xla_model.get_stablehlo_bytecode(tensors: Optional[Tensor] = None) bytes [source]¶
以字节码格式获取计算图的 StableHLO。
如果 tensors 非空,则将转储以 tensors 为输出的图。如果 tensors 为空,则将转储整个计算图。
对于推理图,建议将模型输出传递给 tensors。对于训练图,识别“输出”并不直接。建议使用空的 tensors。
- 参数:
tensors (list[torch.Tensor], optional) – 代表 StableHLO 图的输出/根的张量。
- 返回:
StableHLO 模块(字节码格式)。
distributed¶
- class torch_xla.distributed.parallel_loader.MpDeviceLoader(loader, device, **kwargs)[source]¶
使用后台数据上传包装现有的 PyTorch DataLoader。
此类只能与多进程数据并行一起使用。它将包装传入的数据加载器与 ParallelLoader 一起使用,并为当前设备返回 per_device_loader。
- 参数:
loader (
torch.utils.data.DataLoader
) – 要包装的 PyTorch DataLoader。device (torch.device…) – 数据需要发送到的设备。
kwargs – ParallelLoader 构造函数的命名参数。
示例
>>> device = torch_xla.device() >>> train_device_loader = MpDeviceLoader(train_loader, device)
- torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]¶
启用基于多进程的复制。
- 参数:
fn (callable) – 要为参与复制的每个设备调用的函数。该函数将以参与复制的全局进程索引作为第一个参数调用,后跟 args 中传递的参数。
args (tuple) – fn 的参数。默认值:空元组
nprocs (python:int) – 复制的进程/设备数量。目前,如果指定,则可以是 1 或 None(后者将自动转换为最大设备数)。其他数字将导致 ValueError。
join (bool) – 调用是否应阻塞等待已启动进程的完成。默认值:True
daemon (bool) – 是否应将已启动的进程设置为 daemon 标志(参见 Python 多进程 API)。默认值:False
start_method (string) – Python multiprocessing 进程创建方法。默认值:spawn
- 返回:
与 torch.multiprocessing.spawn API 返回的对象相同。如果 nprocs 为 1,则将直接调用 fn 函数,并且 API 将返回 None。
spmd¶
- torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: tuple[Union[tuple[Union[int, str], ...], int, str, NoneType], ...]) XLAShardedTensor [source]¶
使用 XLA 分区规范注解提供的张量。内部,它为 XLA SpmdPartitioner 过程注解相应的 XLATensor 以进行分片。
- 参数:
t (Union[torch.Tensor, XLAShardedTensor]) – 要用 partition_spec 注解的输入张量。
mesh (Mesh) – 描述逻辑 XLA 设备拓扑和底层设备 ID。
partition_spec (PartitionSpec) –
一个或多个设备 mesh 轴的元组,用于描述如何分片输入张量。每个元素可以是
整数:按索引引用 mesh 轴
字符串:按名称引用 mesh 轴
元组:引用多个 mesh 轴
None:相应张量维度将在所有设备上复制
这指定了每个输入秩如何分片(索引到 mesh_shape)或复制(None)。当指定元组时,相应的输入张量轴将沿着元组中的所有 mesh 轴进行分片。请注意,mesh 轴在元组中的指定顺序将影响最终的分片。
示例
>>> import torch_xla.runtime as xr >>> import torch_xla.distributed.spmd as xs >>> mesh_shape = (4, 2) >>> num_devices = xr.global_runtime_device_count() >>> device_ids = np.array(range(num_devices)) >>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) >>> input = torch.randn(8, 32).to('xla') >>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel >>> linear = nn.Linear(32, 10).to('xla') >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
- torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) Tensor [source]¶
清除输入张量的分片注解,并返回一个 cpu 转换的张量。这是一个原地操作,但也会返回相同的 torch.Tensor。
- 参数:
t (Union[torch.Tensor, XLAShardedTensor]) – 我们想要清除分片的张量
- 返回:
无分片的张量。
- 返回类型:
t (torch.Tensor)
示例
>>> import torch_xla.distributed.spmd as xs >>> torch_xla.runtime.use_spmd() >>> t1 = torch.randn(8,8).to('xla') >>> mesh = xs.get_1d_mesh() >>> xs.mark_sharding(t1, mesh, (0, None)) >>> xs.clear_sharding(t1)
- torch_xla.distributed.spmd.set_global_mesh(mesh: Mesh)[source]¶
设置可用于当前进程的全局 mesh。
- 参数:
mesh – (Mesh) 将成为全局 mesh 的 mesh 对象。
示例
>>> import torch_xla.distributed.spmd as xs >>> mesh = xs.get_1d_mesh("data") >>> xs.set_global_mesh(mesh)
- torch_xla.distributed.spmd.get_global_mesh() Optional[Mesh] [source]¶
获取当前进程的全局 Mesh。
- 返回:
(Optional[Mesh]) 如果设置了全局 Mesh,则返回 Mesh 对象;否则返回 None。
- 返回类型:
mesh
示例
>>> import torch_xla.distributed.spmd as xs >>> xs.get_global_mesh()
- torch_xla.distributed.spmd.get_1d_mesh(axis_name: Optional[str] = None) Mesh [源代码]¶
辅助函数,返回所有设备都在一个维度上的 Mesh。
- 参数:
axis_name – (Optional[str]) 可选字符串,用于表示 Mesh 的轴名称
- 返回:
Mesh 对象
- 返回类型:
示例
>>> # This example is assuming 1 TPU v4-8 >>> import torch_xla.distributed.spmd as xs >>> mesh = xs.get_1d_mesh("data") >>> print(mesh.mesh_shape) (4,) >>> print(mesh.axis_names) ('data',)
- class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, list[int]], mesh_shape: tuple[int, ...], axis_names: Optional[tuple[str, ...]] = None)[源代码]¶
描述逻辑 XLA 设备拓扑 Mesh 及其底层资源。
- 参数:
device_ids – 设备(ID)的扁平化列表。列表将被重塑为 mesh_shape 形状的数组,并按行主序填充元素。每个 ID 都索引自 xr.global_runtime_device_attributes() 返回的设备列表。
mesh_shape – 一个整数元组,描述设备 Mesh 的形状。每个元素描述相应轴上的设备数量。
axis_names – Mesh 轴名称的序列。其长度应与 mesh_shape 的长度相匹配。
示例
>>> mesh_shape = (4, 2) >>> num_devices = len(xm.get_xla_supported_devices()) >>> device_ids = np.array(range(num_devices)) >>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) >>> mesh.get_logical_mesh() >>> array([[0, 1], [2, 3], [4, 5], [6, 7]]) >>> mesh.shape() OrderedDict([('x', 4), ('y', 2)])
- class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: tuple[int, ...], dcn_mesh_shape: Optional[tuple[int, ...]] = None, axis_names: Optional[tuple[str, ...]] = None)[源代码]¶
- 创建通过 ICI 和 DCN 网络连接的设备混合 Mesh。
逻辑 Mesh 的形状应按网络强度递增的顺序排列,例如 [replica, data, model],其中 model 的网络通信需求最高。
- 参数:
ici_mesh_shape – 内部连接设备的逻辑 Mesh 形状。
dcn_mesh_shape – 外部连接设备的逻辑 Mesh 形状。
示例
>>> # This example is assuming 2 slices of v4-8. >>> ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor) >>> dcn_mesh_shape = (2, 1, 1) >>> mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor')) >>> print(mesh.shape()) >>> >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
experimental¶
debug¶
- torch_xla.debug.metrics.short_metrics_report(counter_names: list = None, metric_names: list = None)[源代码]¶
检索包含完整指标和计数器报告的字符串。
- 参数:
counter_names (list) – 需要打印数据的计数器名称列表。
metric_names (list) – 需要打印数据的指标名称列表。