LLMMaskedCategorical¶
- class torchrl.modules.LLMMaskedCategorical(logits: Tensor, mask: Tensor, ignore_index: int = - 100)[源代码]¶
LLM 优化的掩码分类分布。
此类通过以下方式为 LLM 训练提供更内存高效的方法: 1. 使用 ignore_index=-100 进行 log_prob 计算(无掩码开销) 2. 使用传统掩码进行采样操作
这对于掩码所有 logits 会占用大量内存的大词汇量特别有利。
- 参数:
logits (torch.Tensor) – 事件的对数概率(未归一化),形状为 [B, T, C]。 - B:批次大小(可选) - T:序列长度 - C:词汇量大小(类别数量)
mask (torch.Tensor) –
布尔掩码,指示有效位置/标记。 - 如果形状为 [*B, T]:位置级掩码。True 表示位置有效(所有标记都允许)。 - 如果形状为 [*B, T, C]:标记级掩码。True 表示在该位置标记有效。
警告
标记级掩码比位置级掩码占用更多内存。只有在需要掩码标记时才使用此选项。
ignore_index (int, optional) – log_prob 计算中要忽略的索引。默认为 -100。
- 输入形状
- 用例
- 位置级掩码
>>> logits = torch.randn(2, 10, 50000) # [B=2, T=10, C=50000] >>> mask = torch.ones(2, 10, dtype=torch.bool) # [B, T] >>> mask[0, :5] = False # mask first 5 positions of first sequence >>> dist = LLMMaskedCategorical(logits=logits, mask=mask) >>> tokens = torch.randint(0, 50000, (2, 10)) # [B, T] >>> tokens[0, :5] = -100 # set masked positions to ignore_index >>> log_probs = dist.log_prob(tokens) >>> samples = dist.sample() # [B, T]
- 标记级掩码
>>> logits = torch.randn(2, 10, 50000) >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool) # [B, T, C] >>> mask[0, :5, :1000] = False # mask first 1000 tokens for first 5 positions >>> dist = LLMMaskedCategorical(logits=logits, mask=mask) >>> tokens = torch.randint(0, 50000, (2, 10)) >>> # Optionally, set tokens at fully-masked positions to ignore_index >>> log_probs = dist.log_prob(tokens) >>> samples = dist.sample() # [B, T]
注意事项
对于 log_prob,tokens 的形状必须为 [B, T],并且包含有效的标记索引(0 <= token < C),或为已掩码/忽略的位置使用 ignore_index。
对于标记级掩码,如果某个位置的标记被掩码,则该条目的 log_prob 将返回 -inf。
对于位置级掩码,如果某个位置被掩码(ignore_index),则该条目的 log_prob 将返回 0.0(对于交叉熵损失正确)。
采样始终遵循掩码(被掩码的标记/位置从不被采样)。
所有记录的用例都由 test_distributions.py 中的测试覆盖。
- log_prob(value: Tensor) Tensor [源代码]¶
使用 ignore_index 方法计算对数概率。
这是一种内存高效的方法,因为它不需要掩码 logits。value 张量应使用 ignore_index 来表示已掩码的位置。
- property masked_dist: Categorical¶
获取用于采样操作的掩码分布。
- property position_level_masking: bool¶
掩码是位置级的(True)还是标记级的(False)。
- rsample(sample_shape: torch.Size | Sequence[int] | None = None) Tensor [源代码]¶
使用掩码 logits 进行重参数化采样。
- sample(sample_shape: torch.Size | Sequence[int] | None = None) Tensor [源代码]¶
使用掩码 logits 从分布中采样。