评价此页

Data#

本模块提供了实现容错数据加载器的辅助类。

我们建议使用 torchdata 的 StatefulDataLoader 来频繁地为每个副本的数据加载器创建检查点,以避免重复批次。

class torchft.data.DistributedSampler(dataset: Dataset, replica_rank: int, num_replica_groups: int, group_rank: Optional[int] = None, num_replicas: Optional[int] = None, **kwargs: object)[source]#

Bases: DistributedSampler

DistributedSampler 扩展了标准的 PyTorch DistributedSampler,增加了一个 num_replica_groups 参数,用于在容错副本组之间分片数据。

torchft 无法提前知道副本组的数量,因此我们需要将其设置为最大值。

当与 torchft 一起使用时,此采样器本质上是有损的。torchft 会偶尔丢弃批次,并且如果一个副本组宕机,该组的示例将永远不会被使用。这可能导致在使用小型数据集时出现不平衡。

这将把输入数据集分片为 num_replicas*num_replica_group 个分片。

每个分片排名通过以下方式计算: rank + num_replicas*replica_rank

所有工作节点上的 num_replicas 和 replica_rank 必须相同。