Checkpointing#
此模块实现了用于 checkpointing 和从 checkpoint 恢复训练的方法。
- class torchft.checkpointing.CheckpointTransport[source]#
Bases:
Generic
[T
],ABC
- disallow_checkpoint() None [source]#
在 send_checkpoint 后调用,以等待 checkpoint 发送完成。
一旦此函数返回,state_dict 可能会被修改,因此不应再发送任何数据。
- abstract recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) T [source]#
从指定 rank 接收 checkpoint。
- 参数
src_rank – 要从中接收 checkpoint 的 rank
metadata – 远程 CheckpointTransport 返回的元数据
step – 要接收的步数
timeout – 等待 checkpoint 的超时时间
- class torchft.checkpointing.HTTPTransport(timeout: timedelta, num_chunks: int)[source]#
Bases:
CheckpointTransport
[T
]这是一个 HTTP 服务器,可用于在 worker 之间传输 checkpoints。
这使得 worker 能够通过从现有 worker 获取当前权重来快速恢复。
- 参数
timeout – HTTP 请求的超时时间
num_chunks – 将 checkpoint 分成的块数(0 表示不分块)
- address() str [source]#
返回从该服务器获取 checkpoint 的 HTTP 地址。必须将步数附加到地址的末尾。
格式: http://host:port/checkpoint/1234
- 返回
一个 HTTP 地址
- recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) T [source]#
从指定 rank 接收 checkpoint。
- 参数
src_rank – 要从中接收 checkpoint 的 rank
metadata – 远程 CheckpointTransport 返回的元数据
step – 要接收的步数
timeout – 等待 checkpoint 的超时时间