快捷方式

LLMMaskedCategorical

class torchrl.modules.LLMMaskedCategorical(logits: Tensor, mask: Tensor, ignore_index: int = - 100)[源代码]

LLM 优化的掩码分类分布。

此类通过以下方式为 LLM 训练提供了更节省内存的方法:1. 在 log_prob 计算中使用 ignore_index=-100(无掩码开销)2. 在采样操作中使用传统掩码

这对于词汇量大的情况特别有益,因为掩盖所有 logits 可能会占用大量内存。

参数:
  • logits (torch.Tensor) – 事件对数概率(未归一化),形状为 [B, T, C]。 - B:批次大小(可选) - T:序列长度 - C:词汇量大小(类别数)

  • mask (torch.Tensor) –

    布尔掩码,指示有效位置/标记。 - 如果形状为 [*B, T]:位置级掩码。True 表示该位置有效(所有标记都允许)。 - 如果形状为 [*B, T, C]:标记级掩码。True 表示该标记在该位置有效。

    警告

    标记级掩码比位置级掩码占用更多内存。仅在需要掩盖标记时使用。

  • ignore_index (int, optional) – log_prob 计算中要忽略的索引。默认为 -100。

输入形状
  • logits: [*B, T, C](必需)

  • mask: [*B, T](位置级)或 [*B, T, C](标记级)

  • tokens (for log_prob): [*B, T](标记索引,带 ignore_index 用于被掩盖的位置)

使用场景
  1. 位置级掩码
    >>> logits = torch.randn(2, 10, 50000)  # [B=2, T=10, C=50000]
    >>> mask = torch.ones(2, 10, dtype=torch.bool)  # [B, T]
    >>> mask[0, :5] = False  # mask first 5 positions of first sequence
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> tokens = torch.randint(0, 50000, (2, 10))  # [B, T]
    >>> tokens[0, :5] = -100  # set masked positions to ignore_index
    >>> log_probs = dist.log_prob(tokens)
    >>> samples = dist.sample()  # [B, T]
    
  2. 标记级掩码
    >>> logits = torch.randn(2, 10, 50000)
    >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool)  # [B, T, C]
    >>> mask[0, :5, :1000] = False  # mask first 1000 tokens for first 5 positions
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> tokens = torch.randint(0, 50000, (2, 10))
    >>> # Optionally, set tokens at fully-masked positions to ignore_index
    >>> log_probs = dist.log_prob(tokens)
    >>> samples = dist.sample()  # [B, T]
    

注意事项

  • 对于 log_prob,tokens 的形状必须为 [B, T],并包含有效的标记索引(0 <= token < C),或者对于被掩盖/忽略的位置使用 ignore_index。

  • 对于标记级掩码,如果某个位置的标记被掩盖,log_prob 将为该条目返回 -inf。

  • 对于位置级掩码,如果某个位置被掩盖(ignore_index),log_prob 将为该条目返回 0.0(对于交叉熵损失是正确的)。

  • 采样始终遵守掩码(被掩盖的标记/位置永远不会被采样)。

所有文档化的使用场景均包含在 test_distributions.py 中的测试中。

clear_cache()[源代码]

清除缓存的掩码张量以释放内存。

entropy() Tensor[源代码]

使用掩码 logits 计算熵。

log_prob(value: Tensor) Tensor[源代码]

使用 ignore_index 方法计算对数概率。

这很节省内存,因为它不需要掩盖 logits。value 张量应为被掩盖的位置使用 ignore_index。

property logits: Tensor

获取原始 logits。

property mask: Tensor

获取掩码。

property masked_dist: Categorical

获取用于采样操作的掩码分布。

property masked_logits: Tensor

获取用于采样操作的掩码 logits。

property mode: Tensor

使用掩码 logits 获取模式。

property position_level_masking: bool

掩码是位置级的(True)还是标记级的(False)。

property probs: Tensor

从原始 logits 获取概率。

rsample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor[源代码]

使用掩码 logits 进行重参数化采样。

sample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor[源代码]

使用掩码 logits 从分布中采样。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源