分布式数据并行
此模块实现了一个分布式数据并行包装器,该包装器与管理器配合使用以提供容错能力。
-
class torchft.ddp.DistributedDataParallel(manager: Manager, module: Module, **kwargs: object)[source]
Bases: DistributedDataParallel
这是已打补丁的分布式数据并行实现,使其与 torchft 兼容。
重要说明
这要求在第 0 步使用外部机制而不是内部广播(torchft.Manager 将执行此操作)来同步状态。
使用 DDP 的非基本功能可能会导致您的模型着火,因为它们尚未与 torchft 进行测试。
此功能不进行任何健全性检查,例如验证工作程序之间参数的大小是否相同。
-
class torchft.ddp.PureDistributedDataParallel(manager: Manager, module: Module)[source]
Bases: Module
DDP 包装器的纯 Python 重实现。
我们建议使用 DistributedDataParallel 而不是此类。
此方法为每个梯度张量调用一次 allreduce,并且不使用 reducer。这对于实际模型来说可能非常慢。
-
forward(*args: object) → object[source]
定义每次调用时执行的计算。
所有子类都应重写此方法。
注意
尽管 forward pass 的实现需要在该函数内定义,但应在之后调用 Module
实例而不是此实例,因为前者负责运行已注册的钩子,而后者则会默默地忽略它们。