评价此页

TorchRec 简介#

创建日期:2024年10月2日 | 最后更新:2025年7月10日 | 最后验证:2024年10月2日

TorchRec 是一个专为构建基于嵌入(embeddings)的可扩展且高效的推荐系统而设计的 PyTorch 库。本教程将指导您完成安装过程,介绍嵌入的概念,并强调它们在推荐系统中的重要性。此外,还将演示如何在 PyTorch 和 TorchRec 中实现嵌入,重点介绍如何通过分布式训练和高级优化处理大规模嵌入表。

您将学到什么
  • 嵌入的基本原理及其在推荐系统中的作用

  • 如何设置 TorchRec 以在 PyTorch 环境中管理和实现嵌入

  • 探索将大规模嵌入表分布在多个 GPU 上的高级技术

先决条件
  • PyTorch v2.5 或更高版本,配合 CUDA 11.8 或更高版本

  • Python 3.9 或更高版本

  • FBGEMM

安装依赖项#

在 Google Colab 中运行本教程之前,请确保安装以下依赖项

!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121

注意

如果您在 Google Colab 中运行,请确保将运行时类型切换为 GPU。更多信息,请参阅 启用 CUDA

嵌入 (Embeddings)#

在构建推荐系统时,类别特征通常具有极高的基数(cardinality),例如文章、用户、广告等。

为了表示这些实体并建模它们之间的关系,通常会使用嵌入(Embeddings)。在机器学习中,嵌入是高维空间中的实数向量,用于表示诸如单词、图像或用户等复杂数据的含义

推荐系统中的嵌入#

您可能想知道,这些嵌入最初是如何生成的?实际上,嵌入表现为嵌入表(Embedding Table)中的独立行,也称为嵌入权重。这是因为嵌入或嵌入表权重与模型的所有其他权重一样,都是通过梯度下降进行训练的!

嵌入表本质上是一个存储嵌入的大矩阵,具有两个维度 (B, N),其中:

  • B 是表中存储的嵌入数量

  • N 是每个嵌入的维度(N 维嵌入)。

嵌入表的输入表示嵌入查找,用于检索特定索引或行的嵌入。在推荐系统中,例如许多大型系统所使用的那样,唯一 ID 不仅用于特定用户,也用于文章和广告等实体,作为查找各自嵌入表的索引!

在推荐系统中,嵌入通过以下过程进行训练:

  • 输入/查找索引作为唯一 ID 被输入模型。ID 会被哈希到嵌入表的总大小,以防止 ID 大于行数时出现问题。

  • 随后检索并对嵌入进行池化(Pooling),例如取嵌入的总和或平均值。这是必需的,因为每个样本的嵌入数量可能是可变的,而模型需要一致的形状。

  • 嵌入与模型其余部分结合使用以产生预测,例如广告的 点击率 (CTR)

  • 通过预测值和样本标签计算损失,模型的所有权重都会通过梯度下降和反向传播进行更新,包括与该样本关联的嵌入权重

这些嵌入对于表示用户、文章和广告等类别特征至关重要,以便捕获关系并做出准确的推荐。深度学习推荐模型 (DLRM) 论文更深入地讨论了在推荐系统中使用嵌入表的技术细节。

本教程介绍了嵌入的概念,展示了 TorchRec 特有的模块和数据类型,并描述了 TorchRec 分布式训练的工作原理。

import torch

PyTorch 中的嵌入#

在 PyTorch 中,我们有以下类型的嵌入:

  • torch.nn.Embedding:一种嵌入表,其前向传播直接返回嵌入本身。

  • torch.nn.EmbeddingBag:一种嵌入表,其前向传播返回池化(例如求和或平均)后的嵌入,也称为池化嵌入(Pooled Embeddings)

在本节中,我们将简要介绍通过向表传递索引来执行嵌入查找的操作。

num_embeddings, embedding_dim = 10, 4

# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)

# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
    num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
    num_embeddings, embedding_dim, _weight=weights
)

# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)

# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)

embeddings = embedding_collection(ids)

# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)

# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)

print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)

# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))

恭喜!现在您已基本了解如何使用嵌入表——这是现代推荐系统的基石之一!这些表代表了实体及其关系。例如,特定用户与他们喜欢的页面和文章之间的关系。

TorchRec 功能概览#

在上一节中,我们学习了如何使用嵌入表,这是现代推荐系统的基石之一!这些表代表实体和关系(如用户、页面、文章等)。鉴于这些实体不断增加,通常会应用哈希函数来确保 ID 处于特定嵌入表的边界内。然而,为了表示海量实体并减少哈希冲突,这些表可能会变得非常巨大(例如,试想广告的数量)。事实上,这些表可能会大到连 80GB 内存的单张 GPU 都无法容纳。

为了训练具有海量嵌入表的模型,需要将这些表分片(sharding)到多个 GPU 上,这带来了并行和优化方面的一系列新问题和机遇。幸运的是,我们有 TorchRec 库,它已经遇到、整合并解决了许多这些问题。TorchRec 是一个提供大规模分布式嵌入原语的库

接下来,我们将探讨 TorchRec 库的主要功能。我们将从 torch.nn.Embedding 开始,将其扩展到自定义的 TorchRec 模块,探索生成嵌入分片计划的分布式训练环境,了解 TorchRec 的内置优化,并扩展模型以支持 C++ 推理。以下是本节内容的快速大纲:

  • TorchRec 模块和数据类型

  • 分布式训练、分片和优化

让我们从导入 TorchRec 开始

import torchrec

本节介绍 TorchRec 模块和数据类型,包括 EmbeddingCollectionEmbeddingBagCollectionJaggedTensorKeyedJaggedTensorKeyedTensor 等。

EmbeddingBagEmbeddingBagCollection#

我们已经探索了 torch.nn.Embeddingtorch.nn.EmbeddingBag。TorchRec 通过创建嵌入集合来扩展这些模块,即可以使用多个嵌入表的模块,即 EmbeddingCollectionEmbeddingBagCollection。我们将使用 EmbeddingBagCollection 来表示一组嵌入袋(Embedding Bags)。

在下面的示例代码中,我们创建了一个 EmbeddingBagCollection (EBC),其中包含两个嵌入袋,一个代表产品,另一个代表用户。每个表(product_tableuser_table)均由 4096 个维度为 64 的嵌入表示。

ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        ),
    ],
)
print(ebc.embedding_bags)

让我们检查 EmbeddingBagCollection 的前向传播方法以及该模块的输入和输出。

import inspect

# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))

TorchRec 输入/输出数据类型#

TorchRec 为其模块的输入和输出提供了独特的数据类型:JaggedTensorKeyedJaggedTensorKeyedTensor。您可能会问,为什么要创建新的数据类型来表示稀疏特征?要回答这个问题,我们必须理解代码中是如何表示稀疏特征的。

稀疏特征也称为 id_list_featureid_score_list_feature,它们是用作查找表索引的 ID,用于检索相应 ID 的嵌入。举个简单的例子,想象一个稀疏特征是用户互动过的广告。输入本身是一组用户互动过的广告 ID,而检索到的嵌入就是这些广告的语义表示。在代码中表示这些特征的难点在于,在每个输入样本中,ID 的数量是可变的。用户今天可能只与一个广告互动,而明天可能与三个广告互动。

如下所示是一个简单的表示,其中 lengths 张量表示批次中每个样本的索引数量,values 张量包含索引本身。

# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])

# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])

接下来,让我们看看偏移量(offsets)以及每个批次中包含的内容。

# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)

print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
    "Second Batch: ",
    id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)

from torchrec import JaggedTensor

# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)

# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())

# Convert to list of values
print("List of Values: ", jt.to_dense())

# ``__str__`` representation
print(jt)

from torchrec import KeyedJaggedTensor

# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!

product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())

# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())

# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())

# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())

# ``KeyedJaggedTensor`` string representation
print(kjt)

# Q2: What are the offsets for the ``KeyedJaggedTensor``?

# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result

# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())

# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)

# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)

恭喜!现在您已经了解了 TorchRec 模块和数据类型。为自己坚持到这里点个赞吧。接下来,我们将学习分布式训练和分片。

分布式训练和分片#

既然我们已经掌握了 TorchRec 模块和数据类型,是时候更进一步了。

请记住,TorchRec 的主要目的是为分布式嵌入提供原语。到目前为止,我们只在单个设备上处理嵌入表。这在嵌入表很小时是可行的,但在生产环境中通常并非如此。嵌入表往往变得非常巨大,以至于一个表无法放在单个 GPU 上,这就产生了对多个设备和分布式环境的需求。

在本节中,我们将探索如何利用 TorchRec 设置分布式环境,了解实际的生产训练是如何完成的,并探索嵌入表的分片。

本节同样只会使用 1 个 GPU,尽管它将以分布式方式处理。这只是训练方面的限制,因为训练是每个 GPU 一个进程。推理阶段则没有此要求。

在下面的示例代码中,我们设置了 PyTorch 分布式环境。

警告

如果您在 Google Colab 中运行此代码,则只能调用此单元格一次。再次调用将导致错误,因为进程组只能初始化一次。

import os

import torch.distributed as dist

# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"

# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")

print(f"Distributed environment initialized: {dist}")

分布式嵌入#

我们已经使用过主要的 TorchRec 模块:EmbeddingBagCollection。我们研究了它的工作原理以及 TorchRec 中数据的表示方式。然而,我们还没有探讨 TorchRec 的主要部分之一,即分布式嵌入

迄今为止,GPU 是机器学习工作负载最受欢迎的选择,因为它们每秒能执行的浮点运算次数(FLOPs)比 CPU 高出几个数量级。然而,GPU 受到可用快速内存(HBM,类似于 CPU 的 RAM)稀缺的限制,通常仅有几十 GB。

推荐系统模型可能包含远超单张 GPU 内存限制的嵌入表,因此需要将嵌入表分布在多个 GPU 上,即模型并行(Model Parallel)。另一方面,数据并行(Data Parallel)是指将整个模型复制到每个 GPU 上,每个 GPU 获取不同的数据批次进行训练,并在反向传播时同步梯度。

模型中计算需求较少但内存需求较大(嵌入)的部分使用模型并行进行分布,而计算需求较大但内存需求较小(密集层、MLP 等)的部分则使用数据并行进行分布

分片 (Sharding)#

为了分布嵌入表,我们将嵌入表拆分为多个部分并将其放置在不同的设备上,这被称为“分片”。

分片嵌入表的方法有很多种。最常见的方法是:

  • 表级分片 (Table-Wise):整个表放置在一个设备上

  • 列级分片 (Column-Wise):嵌入表的列被分片

  • 行级分片 (Row-Wise):嵌入表的行被分片

分片模块 (Sharded Modules)#

虽然这一切看起来很难处理和实现,但您很幸运。TorchRec 提供了所有原语以实现简便的分布式训练和推理!实际上,TorchRec 模块有两个对应的类,用于在分布式环境中使用任何 TorchRec 模块:

  • 模块分片器 (Module Sharder):此类公开了一个 shard API,用于处理 TorchRec 模块的分片,从而生成分片模块。* 对于 EmbeddingBagCollection,其分片器是 EmbeddingBagCollectionSharder

  • 分片模块 (Sharded Module):此类是 TorchRec 模块的分片变体。它具有与常规 TorchRec 模块相同的输入/输出,但经过了大量优化,可在分布式环境中工作。* 对于 EmbeddingBagCollection,其分片变体是 ShardedEmbeddingBagCollection

每个 TorchRec 模块都有未分片(unsharded)和分片(sharded)两个变体。

  • 未分片版本用于原型设计和实验。

  • 分片版本用于分布式环境中的训练和推理。

TorchRec 模块的分片版本(例如 EmbeddingBagCollection)将处理模型并行所需的一切,例如在 GPU 之间进行通信,将嵌入分配到正确的 GPU 上。

回顾我们的 EmbeddingBagCollection 模块

ebc

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()

# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"

print(f"Process Group: {pg}")

规划器 (Planner)#

在我们展示分片如何工作之前,必须先了解规划器 (planner),它能帮助我们确定最佳的分片配置。

给定一定数量的嵌入表和 rank(GPU 数量),存在多种不同的分片配置。例如,给定 2 个嵌入表和 2 个 GPU,您可以:

  • 每个 GPU 放置 1 个表

  • 将两个表放在一个 GPU 上,另一个 GPU 不放任何表

  • 在每个 GPU 上放置特定的行和列

考虑到所有这些可能性,我们通常希望得到一种在性能上最优的分片配置。

这就是规划器的用武之地。规划器能够根据嵌入表的数量和 GPU 的数量确定最佳配置。事实上,手动执行此操作非常困难,工程师必须考虑大量因素才能确保分片计划的最优。幸运的是,TorchRec 在使用规划器时提供了自动规划功能。

TorchRec 规划器

  • 评估硬件的内存限制

  • 根据嵌入查找的内存提取量来估计计算量

  • 处理数据特定的因素

  • 考虑带宽等其他硬件特性,以生成最优的分片计划

为了考量所有这些变量,TorchRec 规划器可以接收 大量关于嵌入表的数据、约束、硬件信息和拓扑结构,以帮助为模型生成最优分片计划,这些在各个栈中均有提供。

要了解有关分片的更多信息,请参阅我们的 分片教程

# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=1,
        compute_device="cuda",
    )
)

# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)

print(f"Sharding Plan generated: {plan}")

规划器结果#

如上所示,运行规划器时会有大量输出。我们可以看到正在计算的许多统计数据,以及我们的表最终被放置的位置。

运行规划器的结果是一个静态计划,可以重复用于分片!这使得生产模型的分片保持静态,而不是每次都确定新的分片计划。下面,我们使用分片计划最终生成我们的 ShardedEmbeddingBagCollection

# The static plan that was generated
plan

env = ShardingEnv.from_process_group(pg)

# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))

print(f"Sharded EBC Module: {sharded_ebc}")

使用 LazyAwaitable 进行 GPU 训练#

请记住,TorchRec 是一个针对分布式嵌入高度优化的库。TorchRec 为在 GPU 上实现更高训练性能引入的一个概念是 LazyAwaitable。您会看到 LazyAwaitable 类型作为各种分片 TorchRec 模块的输出。 LazyAwaitable 类型所做的仅仅是尽可能推迟计算某些结果,并像异步类型一样工作。

from typing import List

from torchrec.distributed.types import LazyAwaitable


# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        return torch.ones(self._size)


awaitable = ExampleAwaitable([3, 2])
awaitable.wait()

kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)

kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))

print(kt.keys())

print(kt.values().shape)

# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)

分片 TorchRec 模块的剖析#

现在,我们已经根据生成的分片计划成功分片了一个 EmbeddingBagCollection!该分片模块拥有 TorchRec 的常用 API,这些 API 抽象了多个 GPU 之间的分布式通信/计算。事实上,这些 API 对训练和推理的性能进行了高度优化。以下是 TorchRec 提供的三个用于分布式训练/推理的常用 API:

  • input_dist:处理从 GPU 到 GPU 的输入分发。

  • lookups:使用 FBGEMM TBE(后续会详细介绍)以优化的批处理方式执行实际的嵌入查找。

  • output_dist:处理从 GPU 到 GPU 的输出分发。

输入和输出的分发是通过 NCCL Collectives(即 All-to-Alls)完成的,这是所有 GPU 相互发送和接收数据的方式。TorchRec 与 PyTorch 分布式接口对接以进行集体通信,并为最终用户提供清晰的抽象,无需关心底层细节。

反向传播会以相反的顺序执行所有这些集体通信,以分发梯度。input_distlookupoutput_dist 都取决于分片方案。由于我们采用了表级分片,这些 API 是由 TwPooledEmbeddingSharding 构建的模块。

sharded_ebc

# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists

# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists

优化嵌入查找#

在为一组嵌入表执行查找时,一个简单的解决方案是遍历所有 nn.EmbeddingBags 并为每个表进行查找。这正是标准未分片 EmbeddingBagCollection 所做的。然而,虽然这个解决方案很简单,但速度极慢。

FBGEMM 是一个提供高度优化的 GPU 算子(也称为内核)的库。其中一个算子称为表批处理嵌入 (Table Batched Embedding, TBE),它提供了两个主要优化:

  • 表批处理,允许您通过一次内核调用查找多个嵌入。

  • 优化器融合 (Optimizer Fusion),允许模块根据规范的 PyTorch 优化器和参数进行自我更新。

ShardedEmbeddingBagCollection 使用 FBGEMM TBE 作为查找方式,而不是传统的 nn.EmbeddingBags,以实现优化的嵌入查找。

sharded_ebc._lookups

DistributedModelParallel#

我们现在已经探索了分片单个 EmbeddingBagCollection!我们成功地使用 EmbeddingBagCollectionSharder 和未分片的 EmbeddingBagCollection 生成了一个 ShardedEmbeddingBagCollection 模块。这个工作流程没问题,但通常在实现模型并行时,DistributedModelParallel (DMP) 是标准接口。当使用 DMP 包装模型(本例中为 ebc)时,会发生以下情况:

  1. 决定如何分片模型。DMP 将收集可用的分片器并制定出分片嵌入表(例如 EmbeddingBagCollection)的最佳方案。

  2. 实际分片模型。这包括在相应设备上为每个嵌入表分配内存。

DMP 接收我们刚刚试验过的所有内容,如静态分片计划、分片器列表等。此外,它还提供了一些不错的默认设置,可以无缝分片 TorchRec 模型。在这个演示示例中,由于我们有两个嵌入表和一个 GPU,TorchRec 会将两者都放置在单个 GPU 上。

ebc

model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))

out = model(kjt)
out.wait()

model


from fbgemm_gpu.split_embedding_configs import EmbOptimType

分片最佳实践#

目前,我们的配置仅在 1 个 GPU(或 rank)上进行分片,这很简单:只需将所有表放在 1 个 GPU 的内存中即可。然而,在真实的生产用例中,嵌入表通常分片在数百个 GPU 上,并使用表级、行级和列级等不同的分片方法。确定适当的分片配置(以防止内存不足问题)同时保持内存和计算的平衡对于获得最佳性能至关重要。

添加优化器#

请记住,TorchRec 模块针对大规模分布式训练进行了高度优化。一个重要的优化与优化器有关。

TorchRec 模块提供了一个无缝 API,用于在训练中融合反向传播和优化步骤,从而在性能方面实现显著优化并减少内存使用,同时支持为不同的模型参数分配不同的优化器。

优化器类#

TorchRec 使用 CombinedOptimizer,其中包含 KeyedOptimizers 的集合。CombinedOptimizer 有效地简化了为模型中不同子组处理多个优化器的操作。KeyedOptimizer 扩展了 torch.optim.Optimizer,并通过参数字典进行初始化。每个 EmbeddingBagCollection 中的 TBE 模块都将拥有自己的 KeyedOptimizer,它们会合并为一个 CombinedOptimizer

TorchRec 中的融合优化器#

使用 DistributedModelParallel优化器是融合的,这意味着优化器更新是在反向传播中完成的。这是 TorchRec 和 FBGEMM 的一项优化,其中嵌入梯度不会被具体化,而是直接应用于参数。由于嵌入梯度的大小通常与参数本身相当,这带来了显著的内存节省。

不过,您可以选择使优化器成为 dense(密集)类型,这不会应用此优化,允许您检查嵌入梯度或根据需要对其应用计算。在这种情况下,密集优化器就是您 规范的 PyTorch 模型训练循环(带有 optimizer.step())

一旦通过 DistributedModelParallel 创建了优化器,您仍然需要为不与 TorchRec 嵌入模块关联的其他参数管理一个优化器。要找到这些参数,请使用 in_backward_optimizer_filter(model.named_parameters())。像对待普通的 PyTorch 优化器一样将优化器应用于这些参数,并将此优化器与 model.fused_optimizer 合并为一个 CombinedOptimizer,您可以在训练循环中使用它来调用 zero_gradstep

EmbeddingBagCollection 添加优化器#

我们将通过两种方式执行此操作,它们是等效的,但您可以根据偏好进行选择:

  1. 通过分片器中的 fused_params 传递优化器 kwargs。

  2. 通过 apply_optimizer_in_backward,它将优化器参数转换为 fused_params,以便传递给 EmbeddingBagCollectionEmbeddingCollection 中的 TBE

# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter


# We initialize the sharder with
fused_params = {
    "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
    "learning_rate": 0.02,
    "eps": 0.002,
}

# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(
    ebc, plan.plan[""], env, torch.device("cuda")
)

# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(
    f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}"
)

print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")

import copy

from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)

# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it

# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}

for name, param in ebc_apply_opt.named_parameters():
    print(f"{name=}")
    apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)

sharded_ebc_apply_opt = sharder.shard(
    ebc_apply_opt, plan.plan[""], env, torch.device("cuda")
)

# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))

# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(
    dict(
        in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())
    ).keys()
)

# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")

loss.backward()

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")

结论#

在本教程中,您已经完成了分布式推荐系统模型的训练。如果您对推理感兴趣,TorchRec 仓库 中提供了关于如何以推理模式运行 TorchRec 的完整示例。

更多信息,请参阅我们的 dlrm 示例,其中包含使用 《深度学习推荐模型(用于个性化和推荐系统)》 中描述的方法在 Criteo 1TB 数据集上进行多节点训练的示例。