快捷方式

模型并行

DistributedModelParallel 是使用 TorchRec 优化进行分布式训练的主要 API。

class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None, model_tracker_config: Optional[ModelTrackerConfig] = None)

模型并行功能的入口点。

参数:
  • module (nn.Module) – 要包装的模块。

  • env (Optional[ShardingEnv]) – 包含进程组的分片环境。

  • device (Optional[torch.device]) – 计算设备,默认为 cpu。

  • plan (Optional[ShardingPlan]) – 分片时使用的计划,默认为 EmbeddingShardingPlanner.collective_plan()

  • sharders (Optional[List[ModuleSharder[nn.Module]]]) – 可用于分片的 ModuleSharders,默认为 EmbeddingBagCollectionSharder()

  • init_data_parallel (bool) – 数据并行模块可以是惰性的,即它们延迟参数初始化直到第一次前向传播。传入 True 以延迟数据并行模块的初始化。执行第一次前向传播,然后调用 DistributedModelParallel.init_data_parallel()。

  • init_parameters (bool) – 为仍处于 meta device 上的模块初始化参数。

  • data_parallel_wrapper (Optional[DataParallelWrapper]) – 数据并行模块的自定义包装器。

  • model_tracker_config (Optional[DeltaTrackerConfig]) – 模型跟踪器的配置。

示例

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
copy(device: device) DistributedModelParallel

通过调用每个模块的自定义复制过程,递归地将子模块复制到新设备,因为有些模块需要使用原始引用(例如用于推理的 ShardedModule)。

forward(*args, **kwargs) Any

定义每次调用时执行的计算。

所有子类都应重写此方法。

注意

虽然前向传播的实现需要在该函数内定义,但用户应该在之后调用 Module 实例而不是这个函数,因为前者负责运行注册的钩子,而后者则会静默地忽略它们。

get_delta(consumer: Optional[str] = None) Dict[str, DeltaRows]

返回给定消费者的增量行。

get_model_tracker() ModelDeltaTracker

如果模型跟踪器存在,则返回它。

init_data_parallel() None

请参阅 init_data_parallel 构造函数参数以了解用法。多次调用此方法是安全的。

load_state_dict(state_dict: OrderedDict[str, Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys

将参数和缓冲区从 state_dict 复制到此模块及其后代模块。

如果 strictTrue,则 state_dict 的键必须与此模块的 state_dict() 函数返回的键完全匹配。

警告

如果 assignTrue,则优化器必须在调用 load_state_dict 后创建,除非 get_swap_module_params_on_conversion()True

参数:
  • state_dict (dict) – 包含参数和持久 buffer 的字典。

  • strict (bool, optional) – 是否严格强制 state_dict 中的键与此模块的 state_dict() 函数返回的键匹配。默认为 True

  • assign (bool, optional) – 当设置为 False 时,将保留当前模块中张量的属性;而设置为 True 时,将保留 state dict 中张量的属性。唯一的例外是 Parameterrequires_grad 字段,将保留模块中的值。默认为 False

返回:

  • missing_keys 是一个包含此模块期望但

    在提供的 state_dict 中缺失的任何键的字符串列表。

  • unexpected_keys 是一个字符串列表,包含此模块

    不期望但在提供的 state_dict 中存在的键。

返回类型:

包含 missing_keysunexpected_keys 字段的 NamedTuple

注意

如果参数或缓冲区被注册为 None 且其对应的键存在于 state_dict 中,load_state_dict() 将引发 RuntimeError

property module: Module

用于直接访问分片模块的属性,该模块不会被 DDP、FSDP、DMP 或任何其他并行包装器包装。

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

返回模块缓冲区上的迭代器,同时生成缓冲区的名称和缓冲区本身。

参数:
  • prefix (str) – 为所有 buffer 名称添加前缀。

  • recurse (bool, optional) – 如果为 True,则会生成此模块及其所有子模块的 buffers。否则,仅生成此模块直接成员的 buffers。默认为 True。

  • remove_duplicate (bool, optional) – 是否在结果中删除重复的 buffers。默认为 True。

产生:

(str, torch.Tensor) – 包含名称和缓冲区的元组

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

返回模块参数的迭代器,同时生成参数的名称和参数本身。

参数:
  • prefix (str) – 为所有参数名称添加前缀。

  • recurse (bool) – 如果为 True,则会生成此模块及其所有子模块的参数。否则,仅生成此模块直接成员的参数。

  • remove_duplicate (bool, optional) – 是否在结果中删除重复的参数。默认为 True。

产生:

(str, Parameter) – 包含名称和参数的元组

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
reshard(sharded_module_fqn: str, changed_shard_to_params: Dict[str, ParameterSharding]) None

在 DMP 中重新分片一个已分片的模块,给定一组要更改放置的 ParameterShardings。

此方法允许您动态更改特定模块的分片策略,而无需重新创建整个 DMP。它尤其适用于:1. 适应训练期间不断变化的需求 2. 实现渐进式分片策略 3. 重新平衡设备间的负载 4. A/B 测试不同的分片计划

参数:
  • path_to_sharded_module (str) – DMP 中分片模块的路径。例如,“sparse.ebc”。

  • changed_shard_to_params (Dict[str, ParameterSharding]) – 一个映射参数名称到其新的 ParameterSharding 配置的字典。仅包含需要移动的分片。

示例

``` # 原始分片计划可能将表分片到 2 个 GPU 上 original_plan = {

“table_0’: ParameterSharding(

sharding_type=”table_wise”, ranks=[0, 1, 2, 3], sharding_spec=EnumerableShardingSpec(…)

)

}

# 新的分片计划分片到 4 个 GPU 上 new_plan = {

“weight”: ParameterSharding(

sharding_type=”table_wise”, ranks=[0, 1, 2, 3], sharding_spec=EnumerableShardingSpec(…)

)

}

# 用于仅选择原始计划和新计划之间差异的辅助函数 changed_sharding_params = output_sharding_plan_delta(new_plan)

# 重分片模块并重新分发张量 model.reshard(“embedding_module”, changed_sharding_params) ```

注意事项

  • 模块的分片器必须实现 reshard 方法

  • 重分片涉及在设备间重新分发张量数据,这可能成本很高

  • 重分片后,将为该模块维护优化器状态

  • 分片计划将更新以反映新配置

state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

返回一个字典,其中包含对模块整个状态的引用。

参数和持久缓冲区(例如,运行平均值)都包含在内。键是相应的参数和缓冲区名称。设置为 None 的参数和缓冲区不包含在内。

注意

返回的对象是浅拷贝。它包含对模块参数和缓冲区的引用。

警告

目前 state_dict() 还接受位置参数,用于按顺序传递 destinationprefixkeep_vars。但是,这将被弃用,并且在未来的版本中将强制使用关键字参数。

警告

请避免使用参数 destination,因为它不是为最终用户设计的。

参数:
  • destination (dict, optional) – 如果提供,模块的状态将更新到字典中,并返回相同的对象。否则,将创建一个 OrderedDict 并返回。默认为 None

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''

  • keep_vars (bool, optional) – 默认情况下,state dict 中返回的 Tensor 会与 autograd 分离。如果设置为 True,则不会执行分离。默认为 False

返回:

包含模块整体状态的字典

返回类型:

dict

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源