快捷方式

TorchRec 概念

在本节中,我们将学习 TorchRec 的关键概念,它旨在使用 PyTorch 优化大规模推荐系统。我们将详细了解每个概念的工作原理,并回顾它如何与 TorchRec 的其余部分一起使用。

TorchRec 的模块具有特定的输入/输出数据类型,以有效地表示稀疏特征,包括:

  • JaggedTensor: 一个包装器,用于单个稀疏特征的长度/偏移量和值张量。

  • KeyedJaggedTensor: 有效地表示多个稀疏特征,可以将其视为多个 JaggedTensor

  • KeyedTensor: 一个 torch.Tensor 的包装器,允许通过键访问张量值。

为了实现高性能和高效率,标准的 torch.Tensor 在表示稀疏数据方面效率很低。TorchRec 引入这些新的数据类型,因为它们提供了稀疏输入数据的有效存储和表示。正如您稍后将看到的,KeyedJaggedTensor 使得在分布式环境中输入数据的通信非常高效,从而成为 TorchRec 提供的关键性能优势之一。

在端到端训练循环中,TorchRec 包括以下主要组件:

  • Planner: 接收嵌入表配置、环境设置,并生成模型优化的分片计划。

  • Sharder: 根据分片计划使用不同的分片策略(包括数据并行、表级、行级、表级-行级、列级和表级-列级分片)对模型进行分片。

  • DistributedModelParallel: 结合了 Sharder、Optimizer,并提供了以分布式方式训练模型的入口点。

JaggedTensor

一个 JaggedTensor 通过长度、值和偏移量表示稀疏特征。它之所以被称为“jagged”(锯齿状),是因为它能有效地表示具有可变长度序列的数据。相比之下,标准的 torch.Tensor 假设每个序列都具有相同的长度,而这在现实世界的数据中通常不是这种情况。 JaggedTensor 能够表示此类数据而无需填充,从而使其效率极高。

关键组件

  • Lengths:一个整数列表,表示每个实体的元素数量。

  • Offsets:一个整数列表,表示展平的值张量中每个序列的起始索引。它们提供了长度的替代方案。

  • Values:一个 1D 张量,包含每个实体的实际值,连续存储。

以下是一个简单的示例,展示了每个组件的外观:

# User interactions:
# - User 1 interacted with 2 items
# - User 2 interacted with 3 items
# - User 3 interacted with 1 item
lengths = [2, 3, 1]
offsets = [0, 2, 5]  # Starting index of each user's interactions
values = torch.Tensor([101, 102, 201, 202, 203, 301])  # Item IDs interacted with
jt = JaggedTensor(lengths=lengths, values=values)
# OR
jt = JaggedTensor(offsets=offsets, values=values)

KeyedJaggedTensor

一个 KeyedJaggedTensor 通过引入键(通常是特征名称)来扩展 JaggedTensor 的功能,以标记不同的特征组,例如用户特征和项特征。这是 EmbeddingBagCollectionEmbeddingCollectionforward 中使用的数据类型,因为它们用于在一张表中表示多个特征。

一个 KeyedJaggedTensor 具有一个隐含的批次大小,即 lengths 张量的长度除以键的数量。下面的示例具有批次大小为 2(4 个长度除以 2 个键)。与 JaggedTensor 类似,offsetslengths 的功能相同。您还可以通过访问 KeyedJaggedTensor 中的键来访问特征的 lengthsoffsetsvalues

keys = ["user_features", "item_features"]
# Lengths of interactions:
# - User features: 2 users, with 2 and 3 interactions respectively
# - Item features: 2 items, with 1 and 2 interactions respectively
lengths = [2, 3, 1, 2]
values = torch.Tensor([11, 12, 21, 22, 23, 101, 201, 202])
# Create a KeyedJaggedTensor
kjt = KeyedJaggedTensor(keys=keys, lengths=lengths, values=values)
# Access the features by key
print(kjt["user_features"])
# Outputs user features
print(kjt["item_features"])

Planner

TorchRec Planner 帮助确定模型最佳的分片配置。它评估多种嵌入表分片可能性并优化性能。Planner 执行以下操作:

  • 评估硬件的内存约束。

  • 根据内存获取(例如,嵌入查找)估算计算需求。

  • 处理数据特定因素。

  • 考虑其他硬件特定因素,例如带宽,以生成最优分片计划。

为了确保准确考虑这些因素,Planner 可以整合有关嵌入表、约束、硬件信息和拓扑的数据,以帮助生成最优计划。

嵌入表的 Sharding

TorchRec Sharder 为各种用例提供了多种分片策略。我们概述了一些分片策略及其工作原理,以及它们的优点和局限性。通常,我们建议使用 TorchRec Planner 为您生成分片计划,因为它会为模型中的每个嵌入表找到最佳分片策略。

每种分片策略都决定了如何进行表分割,表是否应该被切分以及如何切分,是否保留某些表的一个或几个副本等。分片结果中的表的每个部分,无论是单个嵌入表还是其一部分,都称为一个分片(shard)。

Visualizing the difference of sharding types offered in TorchRec

图 1:可视化 TorchRec 提供的不同分片方案下的表分片放置

以下是 TorchRec 中所有可用分片类型的列表:

  • 表级 (TW):顾名思义,嵌入表被保留为整个部分并放置在一个 rank 上。

  • 列级 (CW):表沿着 emb_dim 维度进行分割,例如,emb_dim=256 被分割成 4 个分片:[64, 64, 64, 64]

  • 行级 (RW):表沿着 hash_size 维度进行分割,通常在所有 rank 之间平均分割。

  • 表级-行级 (TWRW):表被放置在一个主机上,并在该主机上的 rank 之间按行分割。

  • 网格分片 (GS):一个表被 CW 分片,并且每个 CW 分片被放置在主机上的 TWRW。

  • 数据并行 (DP):每个 rank 保留一个表副本。

分片后,模块被转换为其分片版本,在 TorchRec 中称为 ShardedEmbeddingCollectionShardedEmbeddingBagCollection。这些模块处理输入数据的通信、嵌入查找和梯度。

使用 TorchRec Sharded 模块进行分布式训练

有许多分片策略可用,我们如何确定使用哪一种?每种分片方案都有其成本,该成本与模型大小和 GPU 数量一起决定了哪种分片策略最适合某个模型。

在没有分片的情况下,其中每个 GPU 保留嵌入表的副本(DP),主要成本是计算,其中每个 GPU 在前向传递中查找其内存中的嵌入向量,并在反向传递中更新梯度。

通过分片,会增加通信成本:每个 GPU 需要向其他 GPU 请求嵌入向量查找,并通信计算出的梯度。这通常称为 all2all 通信。在 TorchRec 中,对于给定 GPU 上的输入数据,我们确定每个数据部分的嵌入分片位于何处,并将其发送到目标 GPU。然后,该目标 GPU 将嵌入向量返回给原始 GPU。在反向传递中,梯度被发送回目标 GPU,并且分片会根据优化器进行相应更新。

如上所述,分片要求我们通信输入数据和嵌入查找。TorchRec 在三个主要阶段处理此问题,我们将此称为分片嵌入模块前向传播,它用于 TorchRec 模型的训练和推理。

  • 特征 All to All/输入分布 (input_dist)

    • 将输入数据(以 KeyedJaggedTensor 的形式)通信到包含相关嵌入表分片的相应设备。

  • 嵌入查找

    • 使用在特征 all to all 交换后形成的新输入数据进行嵌入查找。

  • 嵌入 All to All/输出分布 (output_dist)

    • 将嵌入查找数据通信回请求它的相应设备(与设备接收的输入数据一致)。

  • 反向传递的操作顺序相反。

下图展示了其工作原理:

Visualizing the forward pass including the input_dist, lookup, and output_dist of a sharded TorchRec module

图 2:表级分片表的正向传递,包括分片 TorchRec 模块的 input_dist、lookup 和 output_dist

DistributedModelParallel

以上所有内容都汇总到了 TorchRec 用于分片和集成计划的主要入口点。总的来说,DistributedModelParallel 执行以下操作:

  • 通过设置进程组和分配设备类型来初始化环境。

  • 如果未提供 Sharder,则使用默认 Sharder,默认 Sharder 包括 EmbeddingBagCollectionSharder

  • 接收提供的分片计划,如果未提供,则生成一个。

  • 创建模块的分片版本并替换原始模块,例如,将 EmbeddingCollection 转换为 ShardedEmbeddingCollection

  • 默认情况下,使用 DistributedDataParallel 包装 DistributedModelParallel,使模块既是模型并行又是数据并行。

优化器

TorchRec 模块提供了一个无缝的 API,用于在训练中融合反向传递和优化器步骤,从而显著提高性能并减少内存使用,同时还可以为不同的模型参数分配不同的优化器。

Visualizing fusing of optimizer in backward to update sparse embedding table

图 3:将嵌入反向传播与稀疏优化器融合

推理

推理环境与训练环境不同,它们对性能和模型大小非常敏感。TorchRec 推理优化的两个关键区别是:

  • 量化: 推理模型被量化以降低延迟和模型大小。此优化允许我们使用尽可能少的设备进行推理以最小化延迟。

  • C++ 环境: 为了进一步降低延迟,模型在 C++ 环境中运行。

TorchRec 提供以下功能将 TorchRec 模型转换为可用于推理的状态:

  • 用于量化模型的 API,包括使用 FBGEMM TBE 的自动优化。

  • 用于分布式推理的嵌入分片。

  • 将模型编译为 TorchScript(在 C++ 中兼容)。

另请参阅

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源