评价此页

Checkpointing#

此模块实现了用于 checkpointing 和从 checkpoint 恢复训练的方法。

class torchft.checkpointing.CheckpointTransport[source]#

Bases: Generic[T], ABC

disallow_checkpoint() None[source]#

在 send_checkpoint 后调用,以等待 checkpoint 发送完成。

一旦此函数返回,state_dict 可能会被修改,因此不应再发送任何数据。

abstract metadata() str[source]#

返回一个字符串,该字符串将由远程 CheckpointTransport 用于获取 checkpoint。

abstract recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) T[source]#

从指定 rank 接收 checkpoint。

参数
  • src_rank – 要从中接收 checkpoint 的 rank

  • metadata – 远程 CheckpointTransport 返回的元数据

  • step – 要接收的步数

  • timeout – 等待 checkpoint 的超时时间

abstract send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) None[source]#

发送 checkpoint,仅当有 rank 落后时调用。

这可能是异步的。

参数
  • dst_ranks – 要发送到的 ranks

  • step – 要发送的步数

  • state_dict – 要发送的状态字典

  • timeout – 等待 checkpoint 发送完成的超时时间

shutdown(wait: bool = True) None[source]#

用于关闭 checkpoint transport。

参数

wait – 是否等待 transport 关闭

class torchft.checkpointing.HTTPTransport(timeout: timedelta, num_chunks: int)[source]#

Bases: CheckpointTransport[T]

这是一个 HTTP 服务器,可用于在 worker 之间传输 checkpoints。

这使得 worker 能够通过从现有 worker 获取当前权重来快速恢复。

参数
  • timeout – HTTP 请求的超时时间

  • num_chunks – 将 checkpoint 分成的块数(0 表示不分块)

address() str[source]#

返回从该服务器获取 checkpoint 的 HTTP 地址。必须将步数附加到地址的末尾。

格式: http://host:port/checkpoint/1234

返回

一个 HTTP 地址

allow_checkpoint(step: int) None[source]#

允许服务指定步数的 checkpoint。

参数

step – 要服务的步数

disallow_checkpoint() None[source]#

禁止服务 checkpoint。

所有请求都将阻塞,直到调用 allow_checkpoint。

metadata() str[source]#

返回一个字符串,该字符串将由远程 CheckpointTransport 用于获取 checkpoint。

recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) T[source]#

从指定 rank 接收 checkpoint。

参数
  • src_rank – 要从中接收 checkpoint 的 rank

  • metadata – 远程 CheckpointTransport 返回的元数据

  • step – 要接收的步数

  • timeout – 等待 checkpoint 的超时时间

send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) None[source]#

发送 checkpoint,仅当有 rank 落后时调用。

这可能是异步的。

参数
  • dst_ranks – 要发送到的 ranks

  • step – 要发送的步数

  • state_dict – 要发送的状态字典

  • timeout – 等待 checkpoint 发送完成的超时时间

shutdown(wait: bool = True) None[source]#

关闭服务器。