- class torchft.data.DistributedSampler(dataset: Dataset, replica_rank: int, num_replica_groups: int, group_rank: Optional[int] = None, num_replicas: Optional[int] = None, **kwargs: object)[source]#
Bases:
DistributedSampler
DistributedSampler 扩展了标准的 PyTorch DistributedSampler,增加了一个 num_replica_groups 参数,用于在容错副本组之间分片数据。
torchft 无法提前知道副本组的数量,因此我们需要将其设置为最大值。
当与 torchft 一起使用时,此采样器本质上是有损的。torchft 会偶尔丢弃批次,并且如果一个副本组宕机,该组的示例将永远不会被使用。这可能导致在使用小型数据集时出现不平衡。
这将把输入数据集分片为
num_replicas*num_replica_group
个分片。每个分片排名通过以下方式计算:
rank + num_replicas*replica_rank
所有工作节点上的 num_replicas 和 replica_rank 必须相同。