评价此页

torch.nn.functional.ctc_loss#

torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[源代码]#

计算连接主义时序分类(Connectionist Temporal Classification)损失。

详情请参阅 CTCLoss

注意

在某些情况下,当在 CUDA 设备上使用张量并利用 CuDNN 时,此算子可能会选择一个非确定性算法来提高性能。如果这不可取,你可以尝试将操作设置为确定性的(可能以性能为代价),方法是设置 torch.backends.cudnn.deterministic = True。有关更多信息,请参阅 可复现性

注意

此操作在使用 CUDA 设备上的张量时可能会产生非确定性梯度。有关更多信息,请参阅 可复现性

参数
  • log_probs (Tensor) – (T,N,C)(T, N, C)(T,C)(T, C),其中 C = 字母表中的字符数(包括空白字符)T = 输入长度N = 批次大小。输出的对数概率(例如,通过 torch.nn.functional.log_softmax() 获得)。

  • targets (Tensor) – (N,S)(N, S)(sum(target_lengths))。如果 target_lengths 中的所有条目都为零,则可能为空张量。在第二种形式中,目标被假定为已连接。

  • input_lengths (Tensor) – (N)(N)()()。输入的长度(每个都必须 T\leq T)

  • target_lengths (Tensor) – (N)(N)()()。目标的长度

  • blank (int, optional) – 空白标签。默认为 00

  • reduction (str, optional) – 指定应用于输出的规约:'none' | 'mean' | 'sum''none':不进行规约,'mean':输出损失将除以目标长度,然后取批次上的均值,'sum':输出将被求和。默认为 'mean'

  • zero_infinity (bool, optional) – 是否将无穷损失及其相关梯度归零。默认为 False。无穷损失主要发生在输入太短而无法与目标对齐时。

返回类型

张量

示例

>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
>>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()