快捷方式

LLMMaskedCategorical

class torchrl.modules.LLMMaskedCategorical(logits: Tensor, mask: Tensor, ignore_index: int = - 100)[源代码]

LLM 优化的掩码分类分布。

此类通过以下方式为 LLM 训练提供更内存高效的方法: 1. 使用 ignore_index=-100 进行 log_prob 计算(无掩码开销) 2. 使用传统掩码进行采样操作

这对于掩码所有 logits 会占用大量内存的大词汇量特别有利。

参数:
  • logits (torch.Tensor) – 事件的对数概率(未归一化),形状为 [B, T, C]。 - B:批次大小(可选) - T:序列长度 - C:词汇量大小(类别数量)

  • mask (torch.Tensor) –

    布尔掩码,指示有效位置/标记。 - 如果形状为 [*B, T]:位置级掩码。True 表示位置有效(所有标记都允许)。 - 如果形状为 [*B, T, C]:标记级掩码。True 表示在该位置标记有效。

    警告

    标记级掩码比位置级掩码占用更多内存。只有在需要掩码标记时才使用此选项。

  • ignore_index (int, optional) – log_prob 计算中要忽略的索引。默认为 -100。

输入形状
  • logits: [*B, T, C](必需)

  • mask: [*B, T](位置级)或 [*B, T, C](标记级)

  • tokens (用于 log_prob):[*B, T](标记索引,已掩码/忽略的位置使用 ignore_index)

用例
  1. 位置级掩码
    >>> logits = torch.randn(2, 10, 50000)  # [B=2, T=10, C=50000]
    >>> mask = torch.ones(2, 10, dtype=torch.bool)  # [B, T]
    >>> mask[0, :5] = False  # mask first 5 positions of first sequence
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> tokens = torch.randint(0, 50000, (2, 10))  # [B, T]
    >>> tokens[0, :5] = -100  # set masked positions to ignore_index
    >>> log_probs = dist.log_prob(tokens)
    >>> samples = dist.sample()  # [B, T]
    
  2. 标记级掩码
    >>> logits = torch.randn(2, 10, 50000)
    >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool)  # [B, T, C]
    >>> mask[0, :5, :1000] = False  # mask first 1000 tokens for first 5 positions
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> tokens = torch.randint(0, 50000, (2, 10))
    >>> # Optionally, set tokens at fully-masked positions to ignore_index
    >>> log_probs = dist.log_prob(tokens)
    >>> samples = dist.sample()  # [B, T]
    

注意事项

  • 对于 log_prob,tokens 的形状必须为 [B, T],并且包含有效的标记索引(0 <= token < C),或为已掩码/忽略的位置使用 ignore_index。

  • 对于标记级掩码,如果某个位置的标记被掩码,则该条目的 log_prob 将返回 -inf。

  • 对于位置级掩码,如果某个位置被掩码(ignore_index),则该条目的 log_prob 将返回 0.0(对于交叉熵损失正确)。

  • 采样始终遵循掩码(被掩码的标记/位置从不被采样)。

所有记录的用例都由 test_distributions.py 中的测试覆盖。

clear_cache()[源代码]

清除缓存的掩码张量以释放内存。

entropy() Tensor[源代码]

使用掩码 logits 计算熵。

log_prob(value: Tensor) Tensor[源代码]

使用 ignore_index 方法计算对数概率。

这是一种内存高效的方法,因为它不需要掩码 logits。value 张量应使用 ignore_index 来表示已掩码的位置。

property logits: Tensor

获取原始 logits。

property mask: Tensor

获取掩码。

property masked_dist: Categorical

获取用于采样操作的掩码分布。

property masked_logits: Tensor

获取用于采样操作的掩码 logits。

property mode: Tensor

使用掩码 logits 获取模式。

property position_level_masking: bool

掩码是位置级的(True)还是标记级的(False)。

property probs: Tensor

从原始 logits 获取概率。

rsample(sample_shape: torch.Size | Sequence[int] | None = None) Tensor[源代码]

使用掩码 logits 进行重参数化采样。

sample(sample_shape: torch.Size | Sequence[int] | None = None) Tensor[源代码]

使用掩码 logits 从分布中采样。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源