LLMMaskedCategorical¶
- class torchrl.modules.LLMMaskedCategorical(logits: Tensor, mask: Tensor, ignore_index: int = - 100)[源代码]¶
LLM 优化的掩码分类分布。
此类通过以下方式为 LLM 训练提供了更节省内存的方法:1. 在 log_prob 计算中使用 ignore_index=-100(无掩码开销)2. 在采样操作中使用传统掩码
这对于词汇量大的情况特别有益,因为掩盖所有 logits 可能会占用大量内存。
- 参数:
logits (torch.Tensor) – 事件对数概率(未归一化),形状为 [B, T, C]。 - B:批次大小(可选) - T:序列长度 - C:词汇量大小(类别数)
mask (torch.Tensor) –
布尔掩码,指示有效位置/标记。 - 如果形状为 [*B, T]:位置级掩码。True 表示该位置有效(所有标记都允许)。 - 如果形状为 [*B, T, C]:标记级掩码。True 表示该标记在该位置有效。
警告
标记级掩码比位置级掩码占用更多内存。仅在需要掩盖标记时使用。
ignore_index (int, optional) – log_prob 计算中要忽略的索引。默认为 -100。
- 输入形状
- 使用场景
- 位置级掩码
>>> logits = torch.randn(2, 10, 50000) # [B=2, T=10, C=50000] >>> mask = torch.ones(2, 10, dtype=torch.bool) # [B, T] >>> mask[0, :5] = False # mask first 5 positions of first sequence >>> dist = LLMMaskedCategorical(logits=logits, mask=mask) >>> tokens = torch.randint(0, 50000, (2, 10)) # [B, T] >>> tokens[0, :5] = -100 # set masked positions to ignore_index >>> log_probs = dist.log_prob(tokens) >>> samples = dist.sample() # [B, T]
- 标记级掩码
>>> logits = torch.randn(2, 10, 50000) >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool) # [B, T, C] >>> mask[0, :5, :1000] = False # mask first 1000 tokens for first 5 positions >>> dist = LLMMaskedCategorical(logits=logits, mask=mask) >>> tokens = torch.randint(0, 50000, (2, 10)) >>> # Optionally, set tokens at fully-masked positions to ignore_index >>> log_probs = dist.log_prob(tokens) >>> samples = dist.sample() # [B, T]
注意事项
对于 log_prob,tokens 的形状必须为 [B, T],并包含有效的标记索引(0 <= token < C),或者对于被掩盖/忽略的位置使用 ignore_index。
对于标记级掩码,如果某个位置的标记被掩盖,log_prob 将为该条目返回 -inf。
对于位置级掩码,如果某个位置被掩盖(ignore_index),log_prob 将为该条目返回 0.0(对于交叉熵损失是正确的)。
采样始终遵守掩码(被掩盖的标记/位置永远不会被采样)。
所有文档化的使用场景均包含在 test_distributions.py 中的测试中。
- log_prob(value: Tensor) Tensor [源代码]¶
使用 ignore_index 方法计算对数概率。
这很节省内存,因为它不需要掩盖 logits。value 张量应为被掩盖的位置使用 ignore_index。
- property masked_dist: Categorical¶
获取用于采样操作的掩码分布。
- property position_level_masking: bool¶
掩码是位置级的(True)还是标记级的(False)。
- rsample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor [源代码]¶
使用掩码 logits 进行重参数化采样。
- sample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor [源代码]¶
使用掩码 logits 从分布中采样。