注意
前往末尾 下载完整的示例代码。
TorchRec 简介#
创建日期: 2024年10月02日 | 最后更新: 2025年07月10日 | 最后验证: 2024年10月02日
TorchRec 是一个针对使用嵌入式(embeddings)构建可扩展、高效推荐系统的 PyTorch 库。本教程将引导您完成安装过程,介绍嵌入式的概念,并强调其在推荐系统中的重要性。教程将提供使用 PyTorch 和 TorchRec 实现嵌入式的实践演示,重点关注通过分布式训练和高级优化来处理大型嵌入表。
嵌入式的基本原理及其在推荐系统中的作用
如何在 PyTorch 环境中设置 TorchRec 来管理和实现嵌入式
探索将大型嵌入表分布到多个 GPU 上的高级技术
PyTorch v2.5 或更高版本,以及 CUDA 11.8 或更高版本
Python 3.9 或更高版本
安装依赖项#
在 Google Colab 中运行本教程之前,请确保安装以下依赖项
!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121
注意
如果您在 Google Colab 中运行,请确保切换到 GPU 运行时类型。有关更多信息,请参阅 启用 CUDA
嵌入式(Embeddings)#
在构建推荐系统时,类别特征(categorical features)通常具有巨大的基数(cardinality),例如帖子、用户、广告等等。
为了表示这些实体和模拟这些关系,我们使用 嵌入式(embeddings)。在机器学习中,嵌入式是在高维空间中表示实数向量,用于表示单词、图像或用户等复杂数据中的含义。
推荐系统中的嵌入式#
现在您可能会问,这些嵌入式是如何生成的?嗯,嵌入式表示为 嵌入表(Embedding Table) 中的单独行,也称为嵌入权重。之所以如此,是因为嵌入式或嵌入表权重与模型中的其他权重一样,都通过梯度下降进行训练!
嵌入表只是一个用于存储嵌入式的大型矩阵,具有两个维度(B, N),其中
B 是表中存储的嵌入式数量
N 是每个嵌入式的维度数(N 维嵌入式)。
嵌入表的输入代表嵌入查找,用于检索特定索引或行的嵌入式。在许多大型系统中使用的推荐系统中,唯一的 ID 不仅用于特定用户,还跨越帖子和广告等实体,用作相应嵌入表的查找索引!
嵌入式在推荐系统中通过以下过程进行训练
输入/查找索引被作为唯一 ID 输入模型。ID 会被哈希到嵌入表的总大小,以防止出现 ID > 行数的问题。
然后检索嵌入式并进行 池化(pooling),例如取嵌入式的总和或平均值。这是必需的,因为每个示例的嵌入式数量可能不同,而模型需要一致的形状。
嵌入式与模型的其余部分一起用于生成预测,例如广告的 点击率 (CTR)。
根据预测和示例的标签计算损失,并且 模型的所有权重都通过梯度下降和反向传播进行更新,包括与该示例关联的嵌入权重。
这些嵌入式对于表示用户、帖子和广告等类别特征至关重要,以便捕获关系并做出好的推荐。 深度学习推荐模型 (DLRM) 论文更详细地讨论了在推荐系统中使嵌入表的技术细节。
本教程介绍了嵌入式的概念,展示了 TorchRec 特定的模块和数据类型,并说明了 TorchRec 的分布式训练是如何工作的。
import torch
PyTorch 中的嵌入式#
在 PyTorch 中,我们有以下类型的嵌入式:
torch.nn.Embedding:一种嵌入表,其前向传播返回嵌入式本身。torch.nn.EmbeddingBag:嵌入表,其前向传播返回然后被池化的嵌入式,例如总和或平均值,也称为 池化嵌入式 (Pooled Embeddings)。
在本节中,我们将简要介绍通过将索引传递到表中来执行嵌入查找。
num_embeddings, embedding_dim = 10, 4
# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)
# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
num_embeddings, embedding_dim, _weight=weights
)
# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)
# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)
embeddings = embedding_collection(ids)
# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)
# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)
print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)
# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
恭喜!您现在对如何使用嵌入表有了基本的了解——这是现代推荐系统的基础之一!这些表代表实体及其关系。例如,给定用户与他们喜欢过的页面和帖子的关系。
TorchRec 功能概述#
在上面一节中,我们学习了如何使用嵌入表,这是现代推荐系统的基础之一!这些表代表实体和关系,例如用户、页面、帖子等。鉴于这些实体不断增加,通常会应用 哈希 函数来确保 ID 在特定嵌入表的范围内。但是,为了表示大量实体并减少哈希冲突,这些表可能会变得非常庞大(想想广告的数量)。事实上,这些表可能变得如此庞大,以至于即使有 80GB 内存也无法容纳在单个 GPU 上。
为了训练具有庞大嵌入表的模型,需要将这些表分片到 GPU 上,这会带来全新的并行化和优化问题和机遇。幸运的是,我们有 TorchRec 库 <https://docs.pytorch.ac.cn/torchrec/overview.html>`__,它已经遇到、整合并解决了其中许多问题。TorchRec 是一个 提供大规模分布式嵌入式原语的库。
接下来,我们将探索 TorchRec 库的主要功能。我们将从 torch.nn.Embedding 开始,然后扩展到自定义 TorchRec 模块,探索分布式训练环境,为嵌入式生成分片计划,查看固有的 TorchRec 优化,并将模型扩展为可在 C++ 中进行推理。下面是本节内容的快速大纲
TorchRec 模块和数据类型
分布式训练、分片和优化
让我们开始导入 TorchRec
import torchrec
本节将介绍 TorchRec 模块和数据类型,包括 EmbeddingCollection、EmbeddingBagCollection、JaggedTensor、KeyedJaggedTensor、KeyedTensor 等实体。
从 EmbeddingBag 到 EmbeddingBagCollection#
我们已经了解了 torch.nn.Embedding 和 torch.nn.EmbeddingBag。TorchRec 通过创建嵌入式集合来扩展这些模块,换句话说,就是可以拥有多个嵌入表的模块,即 EmbeddingCollection 和 EmbeddingBagCollection。我们将使用 EmbeddingBagCollection 来表示一组嵌入式包。
在下面的示例代码中,我们创建了一个 EmbeddingBagCollection (EBC),其中包含两个嵌入式包,一个代表 产品,一个代表 用户。每个表,product_table 和 user_table,都由一个 64 维、大小为 4096 的嵌入式表示。
ebc = torchrec.EmbeddingBagCollection(
device="cpu",
tables=[
torchrec.EmbeddingBagConfig(
name="product_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["product"],
pooling=torchrec.PoolingType.SUM,
),
torchrec.EmbeddingBagConfig(
name="user_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["user"],
pooling=torchrec.PoolingType.SUM,
),
],
)
print(ebc.embedding_bags)
让我们检查 EmbeddingBagCollection 的前向方法以及模块的输入和输出。
import inspect
# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
TorchRec 输入/输出数据类型#
TorchRec 为其模块的输入和输出提供了不同的数据类型:JaggedTensor、KeyedJaggedTensor 和 KeyedTensor。现在您可能会问,为什么要创建新的数据类型来表示稀疏特征?要回答这个问题,我们必须了解稀疏特征在代码中是如何表示的。
稀疏特征也称为 id_list_feature 和 id_score_list_feature,它们是将被用作嵌入表索引以检索该 ID 的嵌入式的 ID。举一个非常简单的例子,想象一下一个稀疏特征是用户与之交互过的广告。输入本身将是一组用户与之交互过的广告 ID,检索到的嵌入式将是对这些广告的语义表示。在代码中表示这些特征的棘手之处在于,在每个输入示例中,ID 的数量是可变的。有一天,用户可能只与一个广告互动,而第二天他们可能与三个广告互动。
下面展示了一个简单的表示,其中有一个 lengths 张量,表示一个批次中的每个示例有多少索引,以及一个包含索引本身的 values 张量。
# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])
# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])
接下来,让我们看看偏移量以及每个批次包含的内容。
# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)
print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
"Second Batch: ",
id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)
from torchrec import JaggedTensor
# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())
# Convert to list of values
print("List of Values: ", jt.to_dense())
# ``__str__`` representation
print(jt)
from torchrec import KeyedJaggedTensor
# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
product_jt = JaggedTensor(
values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())
# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())
# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())
# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())
# ``KeyedJaggedTensor`` string representation
print(kjt)
# Q2: What are the offsets for the ``KeyedJaggedTensor``?
# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result
# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())
# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)
# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
恭喜!您现在已经了解了 TorchRec 模块和数据类型。为您一路走到这里给自己鼓掌。接下来,我们将学习分布式训练和分片。
分布式训练和分片#
现在我们对 TorchRec 模块和数据类型有了初步了解,是时候更进一步了。
请记住,TorchRec 的主要目的是提供分布式嵌入式的原语。到目前为止,我们只在单个设备上处理了嵌入表。这之所以可行,是因为嵌入表的大小一直很小,但在生产环境中通常不是这样。嵌入表通常会变得非常庞大,一个表无法容纳在单个 GPU 上,这就需要多个设备和分布式环境。
在本节中,我们将通过 TorchRec 探索设置分布式环境、实际生产训练的进行方式以及嵌入表的分片。
本节也只使用 1 个 GPU,但它将以分布式方式进行处理。这仅是训练的限制,因为训练每个 GPU 都有一个进程。推理不会遇到此要求。
在下面的示例代码中,我们设置了 PyTorch 分布式环境。
警告
如果您在 Google Colab 中运行,则只能调用此单元格一次,再次调用将导致错误,因为您只能初始化一次进程组。
import os
import torch.distributed as dist
# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"
# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")
print(f"Distributed environment initialized: {dist}")
分布式嵌入式#
我们已经接触过主要的 TorchRec 模块:EmbeddingBagCollection。我们已经研究了它是如何工作的以及数据在 TorchRec 中是如何表示的。然而,我们尚未探索 TorchRec 的一个主要部分,那就是 分布式嵌入式。
GPU 是目前最受欢迎的机器学习工作负载选择,因为它们能够处理比 CPU 大几个数量级的浮点运算/秒(FLOPs)。然而,GPU 的缺点是快速内存(HBM,类似于 CPU 的 RAM)有限,通常只有几十 GB。
推荐系统模型可能包含远超单个 GPU 内存限制的嵌入表,因此需要将嵌入表分布到多个 GPU 上,也称为 模型并行。另一方面,数据并行 是指在每个 GPU 上复制整个模型,每个 GPU 负责处理不同的数据批次进行训练,并在反向传播时同步梯度。
计算量需求较低但内存需求较高(嵌入式)的模型部分采用模型并行进行分发,而 计算量需求较高但内存需求较低(密集层、MLP 等)的模型部分采用数据并行进行分发。
规划器(Planner)#
在我们展示分片如何工作之前,我们必须了解 规划器,它有助于我们确定最佳分片配置。
给定嵌入表的数量和 GPU 的数量,存在许多不同的分片配置。例如,给定 2 个嵌入表和 2 个 GPU,您可以:
将 1 个表放在每个 GPU 上
将两个表都放在一个 GPU 上,另一个 GPU 上不放任何表
将某些行和列放在每个 GPU 上
考虑到所有这些可能性,我们通常需要一个对性能最优的分片配置。
这就是规划器发挥作用的地方。规划器能够确定给定嵌入表的数量和 GPU 的数量,什么是最佳配置。事实证明,手动完成这项工作极其困难,工程师需要考虑许多因素才能确保最佳的分片计划。幸运的是,TorchRec 在使用规划器时提供了自动规划器。
TorchRec 规划器
评估硬件的内存限制
根据内存获取(如嵌入查找)估算计算量
处理特定于数据 的因素
考虑带宽等其他硬件特定信息,以生成模型的最佳分片计划
为了考虑所有这些变量,TorchRec 规划器可以接受 各种关于嵌入表、约束、硬件信息和拓扑的数据,以帮助生成模型的最佳分片计划,该计划通常在整个堆栈中提供。
要了解更多关于分片的信息,请参阅我们的 分片教程。
# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
topology=Topology(
world_size=1,
compute_device="cuda",
)
)
# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)
print(f"Sharding Plan generated: {plan}")
规划器结果#
如上所示,运行规划器时会产生大量输出。我们可以看到许多正在计算的统计数据以及我们的表最终放置的位置。
运行规划器的结果是一个静态计划,可以用于分片!这允许分片对于生产模型是静态的,而不是每次都确定新的分片计划。下面,我们使用分片计划最终生成我们的 ShardedEmbeddingBagCollection。
# The static plan that was generated
plan
env = ShardingEnv.from_process_group(pg)
# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))
print(f"Sharded EBC Module: {sharded_ebc}")
使用 LazyAwaitable 进行 GPU 训练#
请记住,TorchRec 是一个高度优化的分布式嵌入式库。TorchRec 引入的一个用于提高 GPU 训练性能的概念是 LazyAwaitable `。您将看到 LazyAwaitable 类型作为各种分片 TorchRec 模块的输出。 LazyAwaitable 类型所做的就是尽可能延迟计算某个结果,它通过充当异步类型来做到这一点。
from typing import List
from torchrec.distributed.types import LazyAwaitable
# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
def __init__(self, size: List[int]) -> None:
super().__init__()
self._size = size
def _wait_impl(self) -> torch.Tensor:
return torch.ones(self._size)
awaitable = ExampleAwaitable([3, 2])
awaitable.wait()
kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)
kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))
print(kt.keys())
print(kt.values().shape)
# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
分片 TorchRec 模块的结构#
我们现在已经成功地基于生成的分片计划对 EmbeddingBagCollection 进行了分片!分片模块具有 TorchRec 的通用 API,这些 API 抽象了多个 GPU 之间的分布式通信/计算。事实上,这些 API 针对训练和推理的性能进行了高度优化。以下是 TorchRec 提供的用于分布式训练/推理的三个通用 API:
input_dist: 处理将输入从 GPU 分发到 GPU。lookups: 使用 FBGEMM TBE 以优化、批量的方式执行实际的嵌入查找(稍后会详细介绍)。output_dist: 处理将输出从 GPU 分发到 GPU。
输入和输出的分发是通过 NCCL Collectives,即 All-to-All 来完成的,这是所有 GPU 相互之间发送和接收数据的地方。TorchRec 与 PyTorch 分布式接口进行通信,并为最终用户提供简洁的抽象,消除了对底层细节的担忧。
反向传播执行所有这些集合操作,但顺序相反,用于分发梯度。input_dist、lookup 和 output_dist 都依赖于分片方案。由于我们是按表分片,因此这些 API 是由 TwPooledEmbeddingSharding 构建的模块。
sharded_ebc
# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists
# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
优化嵌入查找#
在执行一组嵌入表的查找时,一个简单的解决方案是迭代所有 nn.EmbeddingBags 并为每个表执行一次查找。这正是标准、未分片的 EmbeddingBagCollection 所做的。但是,虽然这个解决方案很简单,但速度非常慢。
FBGEMM 是一个提供 GPU 运算符(也称为内核)的库,这些运算符经过高度优化。其中一个运算符称为 Table Batched Embedding (TBE),它提供了两个主要的优化:
表批处理(Table batching),允许您使用一个内核调用来查找多个嵌入式。
优化器融合(Optimizer Fusion),允许模块根据规范的 PyTorch 优化器和参数自行更新。
ShardedEmbeddingBagCollection 使用 FBGEMM TBE 进行查找,而不是传统的 nn.EmbeddingBags,以实现优化的嵌入查找。
sharded_ebc._lookups
DistributedModelParallel#
我们现在已经完成了对单个 EmbeddingBagCollection 的分片!我们能够使用 EmbeddingBagCollectionSharder 并使用未分片的 EmbeddingBagCollection 来生成 ShardedEmbeddingBagCollection 模块。这个工作流程是可以的,但在实现模型并行时,通常使用 DistributedModelParallel (DMP) 作为标准接口。当使用 DMP 包装模型(在本例中为 ebc)时,将发生以下情况:
决定如何分片模型。DMP 将收集可用的分片器,并制定一个关于如何最佳分片嵌入表(例如,
EmbeddingBagCollection)的计划。实际分片模型。这包括在适当的设备上为每个嵌入表分配内存。
DMP 接收我们刚才尝试过的所有内容,例如静态分片计划、分片器列表等。然而,它也有一些不错的默认设置,可以无缝分片 TorchRec 模型。在这个玩具示例中,由于我们有两个嵌入表和一个 GPU,TorchRec 将两者都放在单个 GPU 上。
ebc
model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
out = model(kjt)
out.wait()
model
from fbgemm_gpu.split_embedding_configs import EmbOptimType
添加优化器#
请记住,TorchRec 模块针对大规模分布式训练进行了高度优化。一个重要的优化与优化器有关。
TorchRec 模块提供了一个无缝的 API,用于融合反向传播和训练中的优化步骤,从而显著提高性能并减少内存使用,同时还可以对分配给不同模型参数的优化器进行粒度控制。
优化器类#
TorchRec 使用 CombinedOptimizer,它包含一组 KeyedOptimizers。 CombinedOptimizer 有效地简化了处理模型中各种子组的多个优化器的操作。 KeyedOptimizer 扩展了 torch.optim.Optimizer,并通过参数字典进行初始化,该字典暴露了参数。 EmbeddingBagCollection 中的每个 TBE 模块都有自己的 KeyedOptimizer,它们组合成一个 CombinedOptimizer。
TorchRec 中的融合优化器#
使用 DistributedModelParallel 时,优化器是融合的,这意味着优化器更新在反向传播中完成。这是 TorchRec 和 FBGEMM 中的一项优化,其中优化器嵌入式梯度不会被具体化并直接应用于参数。这带来了显著的内存节省,因为嵌入式梯度通常与参数本身的大小相同。
但是,您可以选择将优化器设置为 dense,这样就不会应用此优化,允许您检查嵌入式梯度或根据需要对其应用计算。在这种情况下,密集优化器将是您的 规范的 PyTorch 模型训练循环与优化器。
一旦通过 DistributedModelParallel 创建了优化器,您仍然需要管理不与 TorchRec 嵌入式模块关联的其他参数的优化器。要找到其他参数,请使用 in_backward_optimizer_filter(model.named_parameters())。像处理普通 Torch 优化器一样为这些参数应用优化器,并将此与 model.fused_optimizer 组合成一个 CombinedOptimizer,您可以在训练循环中使用它来执行 zero_grad 和 step。
向 EmbeddingBagCollection 添加优化器#
我们将通过两种方式执行此操作,这两种方式是等效的,但根据您的偏好提供选项:
通过分片器中的
fused_params传递优化器关键字参数。通过
apply_optimizer_in_backward,它将优化器参数转换为fused_params以传递给EmbeddingBagCollection或EmbeddingCollection中的TBE。
# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter
# We initialize the sharder with
fused_params = {
"optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
"learning_rate": 0.02,
"eps": 0.002,
}
# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)
# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(
ebc, plan.plan[""], env, torch.device("cuda")
)
# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(
f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}"
)
print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")
import copy
from torch.distributed.optim import (
_apply_optimizer_in_backward as apply_optimizer_in_backward,
)
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it
# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}
for name, param in ebc_apply_opt.named_parameters():
print(f"{name=}")
apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)
sharded_ebc_apply_opt = sharder.shard(
ebc_apply_opt, plan.plan[""], env, torch.device("cuda")
)
# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))
# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(
dict(
in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())
).keys()
)
# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")
loss.backward()
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")
结论#
在本教程中,您已经完成了分布式推荐系统模型的训练。如果您对推理感兴趣,TorchRec 仓库 包含一个关于如何在推理模式下运行 TorchRec 的完整示例。
有关更多信息,请参阅我们的 dlrm 示例,其中包括使用 用于个性化和推荐系统的深度学习推荐模型 中描述的方法在 Criteo 1TB 数据集上进行多节点训练。