自定义#
创建于:2021年5月4日 | 最后更新:2021年5月4日
本节介绍如何自定义 TorchElastic 以满足您的需求。
启动器(Launcher)#
随 TorchElastic 提供的启动器程序应足以满足大多数用例(请参阅 torchrun (弹性启动))。您可以通过编程方式创建代理(Agent)并为其传递工作进程规范(specs)来实现自定义启动器,如下所示。
# my_launcher.py
if __name__ == "__main__":
args = parse_args(sys.argv[1:])
rdzv_handler = RendezvousHandler(...)
spec = WorkerSpec(
local_world_size=args.nproc_per_node,
fn=trainer_entrypoint_fn,
args=(trainer_entrypoint_fn args.fn_args,...),
rdzv_handler=rdzv_handler,
max_restarts=args.max_restarts,
monitor_interval=args.monitor_interval,
)
agent = LocalElasticAgent(spec, start_method="spawn")
try:
run_result = agent.run()
if run_result.is_failed():
print(f"worker 0 failed with: run_result.failures[0]")
else:
print(f"worker 0 return value is: run_result.return_values[0]")
except Exception ex:
# handle exception
集合点处理器(Rendezvous Handler)#
要实现您自己的集合点(rendezvous),请继承 torch.distributed.elastic.rendezvous.RendezvousHandler 并实现其方法。
警告
实现集合点处理器比较复杂。在开始之前,请确保您完全理解集合点的特性。有关更多信息,请参阅 集合点 (Rendezvous)。
实现完成后,您可以在创建代理时将自定义的集合点处理器传递给工作进程规范。
spec = WorkerSpec(
rdzv_handler=MyRendezvousHandler(params),
...
)
elastic_agent = LocalElasticAgent(spec, start_method=start_method)
elastic_agent.run(spec.role)
指标处理器(Metric Handler)#
TorchElastic 会发出平台级指标(请参阅 指标)。默认情况下,指标会发送到 /dev/null,因此您将看不到它们。要将指标推送到您基础设施中的指标处理服务,请实现 torch.distributed.elastic.metrics.MetricHandler 并在您的自定义启动器中进行 配置 (configure)。
# my_launcher.py
import torch.distributed.elastic.metrics as metrics
class MyMetricHandler(metrics.MetricHandler):
def emit(self, metric_data: metrics.MetricData):
# push metric_data to your metric sink
def main():
metrics.configure(MyMetricHandler())
spec = WorkerSpec(...)
agent = LocalElasticAgent(spec)
agent.run()
事件处理器(Events Handler)#
TorchElastic 支持事件记录(请参阅 事件)。事件模块定义了 API,允许您记录事件并实现自定义的 EventHandler。EventHandler 用于将 torchelastic 执行期间产生的事件发布到不同的源,例如 AWS CloudWatch。默认情况下,它使用 torch.distributed.elastic.events.NullEventHandler,该处理器会忽略事件。要配置自定义事件处理器,您需要实现 torch.distributed.elastic.events.EventHandler 接口,并在您的自定义启动器中进行 配置 (configure)。
# my_launcher.py
import torch.distributed.elastic.events as events
class MyEventHandler(events.EventHandler):
def record(self, event: events.Event):
# process event
def main():
events.configure(MyEventHandler())
spec = WorkerSpec(...)
agent = LocalElasticAgent(spec)
agent.run()