表批量嵌入 (TBE) 推理模块¶
稳定版 API¶
- class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: list[tuple[str, int, int, fbgemm_gpu.split_embedding_configs.SparseType, fbgemm_gpu.split_table_batched_embeddings_ops_common.EmbeddingLocation]], feature_table_map: list[int] | None = None, index_remapping: list[torch.Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: list[tuple[torch.Tensor, torch.Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: list[list[str]] | None = None, indices_dtype: dtype = torch.int32)¶
支持 FP32/FP16/FP8/INT8/INT4/INT2 权重的 nn.EmbeddingBag(sparse=False) 推理版本的表批量版本
- 参数:
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –
嵌入规格列表。每个规格描述了一个物理嵌入表的规格。每个规格都是一个元组,包含嵌入行数、嵌入维度(必须是 4 的倍数)、表位置(EmbeddingLocation)和计算设备(ComputeDevice)。
可用的 EmbeddingLocation 选项是
DEVICE = 将嵌入表放置在 GPU 全局内存 (HBM) 中
MANAGED = 将嵌入放置在统一虚拟内存中(GPU 和 CPU 均可访问)
MANAGED_CACHING = 将嵌入表放置在统一虚拟内存中,并使用 GPU 全局内存 (HBM) 作为缓存
HOST = 将嵌入表放置在 CPU 内存 (DRAM) 中
MTIA = 将嵌入表放置在 MTIA 内存中
可用的 ComputeDevice 选项是
CPU = 在 CPU 上执行表查找
CUDA = 在 GPU 上执行表查找
MTIA = 在 MTIA 上执行表查找
feature_table_map (Optional[List[int]] = None) – 一个可选列表,指定特征表映射。feature_table_map[i] 指示特征 i 映射到的物理嵌入表。
index_remapping (Optional[List[Tensor]] = None) – 用于剪枝的索引重映射
pooling_mode (PoolingMode = PoolingMode.SUM) –
池化模式。可用的 PoolingMode 选项是
SUM = 求和池化
MEAN = 平均池化
NONE = 无池化(序列嵌入)
device (Optional[Union[str, int, torch.device]] = None) – 要放置张量的当前设备
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –
输入检查模式。可用的 BoundsCheckMode 选项是
NONE = 跳过边界检查
FATAL = 遇到无效索引/偏移时抛出错误
WARNING = 遇到无效索引/偏移时打印警告消息并修复(将无效索引设置为零,并将无效偏移调整到边界内)
IGNORE = 静默修复无效索引/偏移(将无效索引设置为零,并将无效偏移调整到边界内)
weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None) – [T]
pruning_hash_load_factor (float = 0.5) – 剪枝哈希的负载因子
use_array_for_index_remapping (bool = True) – 如果为 True,则使用数组进行索引重映射。否则,使用哈希映射。
output_dtype (SparseType = SparseType.FP16) – 输出张量的数据类型。
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –
缓存算法(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。选项包括
LRU = 最近最少使用
LFU = 最不常使用
cache_load_factor (float = 0.2) – 在使用 EmbeddingLocation.MANAGED_CACHING 时用于确定缓存容量的因子。缓存容量为 cache_load_factor * 所有嵌入表的总行数
cache_sets (int = 0) – 缓存集数(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)
cache_reserved_memory (float = 0.0) – 为非缓存目的在 HBM 中保留的内存量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。
enforce_hbm (bool = False) – 如果为 True,在使用 EmbeddingLocation.MANAGED_CACHING 时将所有权重/动量放置在 HBM 中
record_cache_metrics (Optional[RecordCacheMetrics] = None) – 如果 RecordCacheMetrics.record_cache_miss_counter 为 True,则记录命中次数、请求次数等,如果 RecordCacheMetrics.record_tablewise_cache_miss 为 True,则记录类似指标按表统计
gather_uvm_cache_stats (Optional[bool] = False) – 如果为 True,在 EmbeddingLocation 设置为 MANAGED_CACHING 时收集缓存统计信息
row_alignment (Optional[int] = None) – 行对齐
fp8_exponent_bits (Optional[int] = None) – 使用 FP8 时的指数位数
fp8_exponent_bias (Optional[int] = None) – 使用 FP8 时的指数偏置
cache_assoc (int = 32) – 缓存的组数
scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – 缩放和偏置的大小(字节)
cacheline_alignment (bool = True) – 如果为 True,则将每个表对齐到 128b 缓存行边界
uvm_host_mapped (bool = False) – 如果为 True,则使用 malloc + cudaHostRegister 分配每个 UVM 张量。否则,使用 cudaMallocManaged
reverse_qparam (bool = False) – 如果为 True,则在每行的末尾加载 qparams。否则,在每行的开头加载 qparams。
feature_names_per_table (Optional[List[List[str]]] = None) – 一个可选列表,指定每个表的特征名称。feature_names_per_table[t] 表示表 t 的特征名称。
indices_dtype (torch.dtype = torch.int32) – 将传递给 forward() 调用的索引张量的预期数据类型。此信息将用于构造 remap_indices 数组/哈希。选项包括 torch.int32 和 torch.int64。
- assign_embedding_weights(q_weight_list: list[tuple[torch.Tensor, torch.Tensor | None]]) None¶
使用来自权重和缩放-偏移量列表的值分配 self.split_embedding_weights()。
- forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) Tensor¶
定义每次调用时执行的计算。
所有子类都应重写此方法。
注意
尽管 forward pass 的实现需要在此函数中定义,但你应该在之后调用
Module实例,而不是这个函数,因为前者会处理已注册的钩子,而后者会静默忽略它们。
- recompute_module_buffers() None¶
计算位于 meta 设备上且在 reset_weights_placements_and_offsets() 中未具体化的模块缓冲区。目前这些缓冲区是 weights_tys、rows_per_table、D_offsets 和 bounds_check_warning。剪枝相关或 uvm 相关的缓冲区目前不计算。
- split_embedding_weights(split_scale_shifts: bool = True) list[tuple[torch.Tensor, torch.Tensor | None]]¶
返回按表拆分的权重列表
- split_embedding_weights_with_scale_bias(split_scale_bias_mode: int = 1) list[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]]¶
返回按表拆分的权重列表,split_scale_bias_mode
0: 返回一行;1: 返回权重 + 缩放_偏移;2: 返回权重、缩放、偏移。