快捷方式

设置 TorchRec

在本节中,我们将

  • 了解使用 TorchRec 的要求

  • 设置用于集成 TorchRec 的环境

  • 运行基本的 TorchRec 代码

系统要求

TorchRec 通常在 AWS Linux 上进行测试,并且应该在类似的环境中运行。下面展示了目前测试的兼容性矩阵

Python 版本

3.9, 3.10, 3.11, 3.12

计算平台

CPU、CUDA 11.8、CUDA 12.1、CUDA 12.4

除了这些要求之外,TorchRec 的核心依赖是 PyTorch 和 FBGEMM。如果您的系统通常与这两个库兼容,那么它应该足以支持 TorchRec。

版本兼容性

TorchRec 和 FBGEMM 在发布时具有匹配的版本号,它们会一起进行测试

  • TorchRec 1.0 与 FBGEMM 1.0 兼容

  • TorchRec 0.8 与 FBGEMM 0.8 兼容

  • TorchRec 0.8 可能与 FBGEMM 0.7 不兼容

此外,TorchRec 和 FBGEMM 仅在新的 PyTorch 版本发布时一起发布。因此,特定版本的 TorchRec 和 FBGEMM 应对应特定的 PyTorch 版本

  • TorchRec 1.0 与 PyTorch 2.5 兼容

  • TorchRec 0.8 与 PyTorch 2.4 兼容

  • TorchRec 0.8 可能与 PyTorch 2.3 不兼容

安装

下面以 CUDA 12.1 的安装为例。对于 CPU、CUDA 11.8 或 CUDA 12.4,请将 cu121 分别替换为 cpucu118cu124

pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121
pip install torchmetrics==1.0.3
pip install torchrec --index-url https://download.pytorch.org/whl/cu121
pip install torch
pip install fbgemm-gpu
pip install torchrec
pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu121
pip install torchmetrics==1.0.3
pip install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121

您也可以从源代码构建 TorchRec,以开发 TorchRec 的最新更改。要从源代码构建,请参阅此 参考

运行一个简单的 TorchRec 示例

现在我们已经正确设置了 TorchRec,让我们运行一些 TorchRec 代码!下面,我们将使用 TorchRec 数据类型 KeyedJaggedTensorEmbeddingBagCollection 进行简单的正向传播。

import torch

import torchrec
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=16,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=16,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        )
    ]
)

product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

# Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

print("Call EmbeddingBagCollection Forward: ", ebc(kjt))

将以上代码保存到名为 torchrec_example.py 的文件中。然后,您应该能够从终端执行它:

python torchrec_example.py

您应该会看到 KeyedTensor 的输出,其中包含结果嵌入。恭喜!您已成功安装并运行了您的第一个 TorchRec 程序!

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源