Interactive Distributed Applications with Monarch

Author: Amir Afzali

Introduction

As deep learning models keep growing in size and complexity, training them efficiently requires coordinating computation across many GPUs and nodes. In this tutorial, you will learn how to use Monarch's actor framework together with TorchTitan to easily set up and run large-scale distributed workflows on a SLURM-managed cluster. Monarch lets us drive a large cluster of machines (organized into a mesh) as if we were developing in a single-host, single-process environment.

What is Monarch?

Monarch is an actor framework designed to simplify the development of distributed applications. At its core, Monarch provides:

  • Actor-based programming model: encapsulate stateful computation in actors that can run on remote processes and machines

  • Process mesh abstraction: easily manage and coordinate distributed processes across a cluster, with support for scalable actor messaging

  • Fault tolerance: actors and processes form a tree; failures propagate up the tree, providing sensible default error behavior and enabling fine-grained failure recovery

  • Flexible resource management: supports a variety of cluster schedulers, including SLURM, Kubernetes, custom host management, and local processes

  • Integrated monitoring: streams logs from remote processes back to the client for easy debugging and aggregation

For more details, see the Monarch documentation.
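
To make the programming model concrete before we wire in TorchTitan, here is a minimal, illustrative actor sketch that uses only APIs appearing later in this tutorial (Actor, @endpoint, current_rank); the class and method names are made up for the example.

from monarch.actor import Actor, current_rank, endpoint


class EchoActor(Actor):
    """A toy actor: each instance lives in its own process on a mesh."""

    def __init__(self, greeting: str) -> None:
        self.greeting = greeting
        # current_rank() identifies this actor's position in the process mesh
        self.rank = current_rank().rank

    @endpoint
    async def echo(self, message: str) -> str:
        # Runs in the actor's remote process; the result is sent back to the caller
        return f"[rank {self.rank}] {self.greeting}: {message}"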

Why Monarch?

TorchTitan is a PyTorch-native library for large-scale pretraining. While TorchTitan provides excellent distributed training primitives, launching and managing those jobs on a cluster can slow down iteration. Monarch addresses this with:

  1. Simplified cluster interaction: reserve and manage compute resources with simple async Python calls instead of writing bash scripts

  2. Interactive development: modify and re-run training code on an existing allocation without waiting for new resources

  3. Unified workflow: switch seamlessly between local testing and cluster execution using the same code

Prerequisites

This tutorial relies on Titan's nightly build, so make sure your other Torch libraries also track nightly builds.

  1. Monarch nightly installed: install script

  2. TorchTitan nightly installed: TorchTitan installation instructions

  3. A valid **Titan model config file** and **tokenizer** in your working directory (for example, `debug_model.toml` from the TorchTitan configs).

  4. Access to a SLURM cluster

    • Sufficient permissions to reserve nodes and launch jobs.

    • A CUDA environment configured for distributed GPU training.

Now, let's build this up step by step!

Step 1: Reserve Machine Resources

First, we define a function to programmatically reserve a machine allocation.

Monarch highlight: Instead of submitting SBATCH scripts, you can reserve and manage resources interactively from Python. The JobTrait design pattern allows interaction with custom schedulers (such as SLURM and Kubernetes) through a consistent API.

from monarch.job import SlurmJob, JobTrait


def create_slurm_job(
    mesh_name: str,
    num_nodes: int,
    gpus_per_node: int,
    time_limit: str = "06:00:00"
) -> SlurmJob:
    """
    Args:
        mesh_name: Name assigned to the primary mesh for this example.
                   A JobTrait can consist of multiple meshes, and
                   Monarch allows for re-attaching to ongoing jobs.
        num_nodes: Number of nodes allocated per mesh
        gpus_per_node: Number of GPUs per node in the mesh
        time_limit: Wall-clock limit for the allocation (HH:MM:SS)

        Note: SlurmJob is just one instance of a Monarch scheduler interface.
              Consult the JobTrait documentation to find one that's right for your use case.
    """
    default_job_name = "monarch_titan"
    return SlurmJob(
        meshes={mesh_name: num_nodes},
        job_name=default_job_name,
        time_limit=time_limit,
        gpus_per_nodes=gpus_per_node,
        # ... additional args can be passed here
    )
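
A minimal usage sketch of this helper (the full workflow in Step 4 does exactly this; the values here are illustrative):

# Reserve 2 nodes with 8 GPUs each under the mesh name "mesh0"
slurm_job = create_slurm_job(mesh_name="mesh0", num_nodes=2, gpus_per_node=8)

# The job state exposes the named meshes; Step 4 uses it to spawn one process per GPU
job_state = slurm_job.state()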

Step 2: Define the Trainer Actor

Now we create a Monarch actor that wraps TorchTitan's Trainer. This is the key abstraction that allows TorchTitan to run inside Monarch's distributed environment.

Monarch highlight: The actor pattern provides several benefits:

  1. Remote execution: methods decorated with @endpoint can be invoked remotely

  2. Lifecycle management: Monarch takes care of initialization, execution, and cleanup

  3. Error handling: exceptions are properly propagated back to the client, supporting progressive error handling

import torch
from monarch.actor import Actor, current_rank, endpoint
from monarch.utils import setup_env_for_distributed
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer


class TrainerActor(Actor):
    """
    Monarch Actor wrapper for TorchTitan's Trainer.

    This actor encapsulates a complete TorchTitan training process, handling
    initialization, training loop execution, and cleanup. Each instance runs
    on a single GPU in the distributed training job.

    The actor's lifetime:
        1. __init__: Initialize with job configuration
        2. start_training:
           Execute the training loop
           Destroy process group and release resources

    Attributes:
        job_config: TorchTitan configuration for this trainer
        uid: Unique identifier for logging (includes rank)
    """

    def __init__(self, job_config: "JobConfig") -> None:
        """
        Initialize the trainer actor.

        Args:
            job_config: TorchTitan JobConfig with training parameters
        """
        self.job_config = job_config

        # current_rank() provides access to this actor's rank in the process mesh
        self.rank = current_rank().rank
        self.uid = f"[trainer_{self.rank}]"

    @endpoint
    async def ping_rank(self) -> None:
        """
        A dummy logging function we will use for demonstration purposes.
        """
        logger.info(f"{self.uid} Ping!")

    @endpoint
    async def start_training(self) -> None:
        """
        Execute the TorchTitan training loop.

        This remote endpoint:
        1. Initializes TorchTitan's logger
        2. Creates a Trainer instance with the job configuration
        3. Runs the training loop
        4. Handles cleanup and error conditions

        The @endpoint decorator makes this method callable from the Monarch
        client, even though it runs on a remote GPU node.

        Raises:
            Exception: Any exception from TorchTitan training is propagated
                      back to the client
        """
        init_logger()
        trainer: Trainer | None = None
        try:
            # Initialize TorchTitan trainer
            trainer = Trainer(self.job_config)
            logger.info(f"{self.uid} initialized successfully and starting training")

            # Run the training loop
            trainer.train()

        except Exception as e:
            logger.error(f"{self.uid} training failed: {e}")
            if trainer:
                trainer.close()
            # Note: error is propagated back to the controller
            raise e

        else:
            # Training completed successfully
            trainer.close()
            logger.info(f"{self.uid} training completed successfully")

        finally:
            # Clean up distributed process group
            torch.distributed.destroy_process_group()
            logger.info(f"{self.uid} trainer cleaned up")

Actor endpoints can be invoked in several modes. We will walk through a concrete example in Step 4: Execute the Training Workflow, but here is some pseudocode illustrating common usage:

try:
    # where mesh0 is made of N nodes, each node having 8 GPUs
    proc_mesh = mesh0.spawn_procs({"gpus": 8})
    trainer_actors = proc_mesh.spawn("trainers", TrainerActor, ...)

    # Call on all ranks
    await trainer_actors.ping_rank.call()

    # Call-and-forget on all ranks
    trainer_actors.ping_rank.broadcast()

    # Call on ONE random rank
    await trainer_actors.ping_rank.choose()

    # Call on the first 3 ranks of node 0
    await trainer_actors.slice(hosts=0, gpus=slice(0, 3)).ping_rank.call()

except Exception as e:
    # handle SupervisionEvents from remote actor failures
    pass

Remote actor endpoints can also take advantage of native Python breakpoints, enabling interactive debugging sessions. For a full deep dive into the Monarch debugger, refer to the documentation.

    @endpoint
    async def ping_debuggable_rank(self) -> None:
        logger.info(f"{self.uid} Ping!")
        if self.rank == 0:
            breakpoint()
        logger.info(f"{self.uid} Pong!")

Step 3: Define Training Parameters

Next, we define some general parameters for the training job and cluster resources. This configuration determines the scale of training (number of nodes and GPUs) as well as a few training hyperparameters.

from dataclasses import dataclass


@dataclass
class RunParams:
    """
    Configuration for cluster resources and training parameters.

    Attributes:
        training_steps: Number of training iterations to run
        model_config: Path to TorchTitan model configuration file
        tokenizer: Path to tokenizer directory
        dataset: Dataset to use for training (e.g., 'c4', 'c4_test')
        num_nodes: Number of compute nodes to request
        gpus_per_node: Number of GPUs per node

    Adjust these values based on your model size and available resources.
    """

    training_steps: int = 50
    model_config: str = "debug_model.toml"
    tokenizer: str = "tokenizer"
    dataset: str = "c4"
    num_nodes: int = 2
    gpus_per_node: int = 8
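
Since make_job_config (defined below) reads these values as class-level defaults (for example, RunParams.num_nodes), a quick way to shrink the job for a first smoke test is to adjust the defaults before building the config; the values shown here are purely illustrative.

# Illustrative: a smaller configuration for a first end-to-end test
RunParams.num_nodes = 1
RunParams.gpus_per_node = 2
RunParams.training_steps = 10
RunParams.dataset = "c4_test"  # small test dataset, as noted in the docstring above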

TorchTitan uses a JobConfig object to control every aspect of training. Here we create a function that builds this configuration from RunParams.

import os
from torchtitan.config import ConfigManager, JobConfig


def make_job_config() -> JobConfig:
    """
    Create a TorchTitan JobConfig from RunParams.

    This function constructs the complete training configuration, including
    parallelism settings, model architecture, and dataset paths
    """
    # Calculate total parallelism based on cluster size
    data_parallel_shard_degree = RunParams.num_nodes * RunParams.gpus_per_node
    output_path = "./outputs"
    # Construct paths relative to script directory
    script_dir = os.getcwd()

    # Build argument list for TorchTitan's ConfigManager
    # These override defaults from the model config file
    default_args = [
        "--job.config_file",
        os.path.join(script_dir, RunParams.model_config),
        "--model.tokenizer_path",
        os.path.join(script_dir, RunParams.tokenizer),
        "--parallelism.data_parallel_shard_degree",
        str(data_parallel_shard_degree),
        "--training.steps",
        str(RunParams.training_steps),
        "--training.dataset",
        RunParams.dataset,
        "--job.dump_folder",
        output_path,
        # continue to configure as needed
    ]
    config_manager = ConfigManager()
    job_config = config_manager.parse_args(default_args)
    return job_config
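
As a quick sanity check, you can build the config locally and inspect a few fields. This sketch assumes the nested JobConfig attributes mirror the CLI flags passed above (for example, --training.steps maps to job_config.training.steps):

job_config = make_job_config()

# Attribute names assumed to mirror the flags given to ConfigManager above
print("training steps:", job_config.training.steps)
print("dataset:", job_config.training.dataset)
print("output folder:", job_config.job.dump_folder)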

Step 4: Execute the Training Workflow

With all the components defined, we can now orchestrate the complete workflow. This is where Monarch's power shines most.

Monarch highlights:

  1. Interactive iteration: once a machine allocation is reserved, you can adjust your logic and relaunch actors without requesting new resources (see the sketch after the workflow code below). SLURM's shared filesystem ensures framework/workspace changes stay in sync across worker nodes.

  2. Transparent logging: all logs from remote workers are streamed back to your client in real time, so debugging feels like local execution

Workflow:

Reserve machines → create process mesh → configure logging → spawn actors → train → clean up

async def execute_training() -> None:
    """
    Execute the complete distributed training workflow.
    """
    job_config = make_job_config()
    slurm_job = None
    mesh_name = "mesh0"
    try:
        # 1. Create a SLURM job with N nodes
        #    This leverages Monarch to reserve a persistent machine allocation
        slurm_job = create_slurm_job(mesh_name, RunParams.num_nodes, RunParams.gpus_per_node)
        job_state = slurm_job.state()

        # 2. Create a process mesh on the machine allocation
        #    This creates one process per GPU across all allocated nodes
        logger.info("Creating process mesh...")
        proc_mesh = job_state.mesh0.spawn_procs({"gpus": RunParams.gpus_per_node})

        # 3. Configure remote logging behavior
        #    - stream_to_client: Forward all remote logs to your local console
        #    - aggregate_window_sec: Batch logs for efficiency
        logger.info("Configuring logging...")
        await proc_mesh.logging_option(
            stream_to_client=True,
            # aggregate_window_sec=None  # Uncomment to disable log batching
        )

        # 4. Setup environment for torch.distributed
        #    This configures torch.distributed across all processes in the mesh
        logger.info("Setting up distributed environment...")
        await setup_env_for_distributed(proc_mesh)

        # 5. Spawn TrainerActor on each GPU
        #    Each process in the mesh creates its own TrainerActor instance
        logger.info("Spawning trainer actors...")
        trainer = proc_mesh.spawn(
            "trainer_actor",  # Name for the actor group
            TrainerActor,  # Actor class to instantiate
            job_config,  # Arguments to __init__
        )

        # 6. Execute the training job across all actors
        #    The .call() method invokes start_training() on all actors in parallel
        logger.info("Starting distributed training...")
        await trainer.start_training.call()

        logger.info("Training completed successfully!")

    except Exception as e:
        logger.error(f"Training workflow failed: {e}")

    finally:
        # Always clean up the machine allocation
        if slurm_job:
            await cleanup_job(slurm_job)
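
To illustrate the interactive-iteration highlight above: because the allocation and its process mesh outlive any single actor group, you can spawn a fresh set of actors (with updated code or configuration) on the same mesh instead of reserving new machines. A rough sketch reusing the proc_mesh created in execute_training; the function and actor-group names are illustrative:

async def iterate_on_existing_mesh(proc_mesh, new_job_config: JobConfig) -> None:
    # Spawn a new actor group on the SAME process mesh; no new SLURM
    # allocation is requested. The name only distinguishes actor groups.
    trainer_v2 = proc_mesh.spawn("trainer_actor_v2", TrainerActor, new_job_config)

    # Run the updated training loop across all ranks
    await trainer_v2.start_training.call()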

Step 5: Clean Up Resources

After training completes (or when you are done experimenting), it is important to release cluster resources by terminating the SLURM job.

Monarch highlight: While you can keep the allocation alive for multiple training runs during development, always remember to release cluster resources when you are done.

async def cleanup_job(job: JobTrait) -> None:
    """
    This function cancels the SLURM job, releasing all reserved nodes back
    to the cluster for other users.

    Args:
        job: A JobTrait, like the one returned from create_slurm_job()

    Note:
        The job will also terminate automatically when the configured TTL
        is exceeded, but explicit cleanup is recommended for long-running
        notebooks or scripts.
    """
    job.kill()
    logger.info("Job terminated successfully")

Step 6: Run the Complete Pipeline

Finally, we wire everything together in a main entry point that launches the workflow.

import asyncio


if __name__ == "__main__":
    """
    Run the complete workflow: reserve resources, train, and cleanup.
    """
    logger.info("Starting Monarch + TorchTitan Distributed Training")

    asyncio.run(execute_training())

    logger.info("Workflow completed!")

Conclusion

Congratulations! In this tutorial, you learned how to combine Monarch's actor framework with TorchTitan for scalable distributed training.

Further Reading

  • Monarch also integrates with TorchFT to provide per-step fault tolerance for replicated workers. You can find a comprehensive proof of concept of this integration in the TorchFT repo.

  • For an interactive notebook covering topics similar to this tutorial, refer to this Monarch example.