torch.nn.functional.ctc_loss#
- torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[source]#
计算连接主义时序分类(CTC)损失。
有关详细信息,请参阅
CTCLoss
。注意
在某些情况下,当在 CUDA 设备上使用张量并使用 CuDNN 时,此运算符可能会选择一个非确定性算法以提高性能。如果这不可取,您可以尝试通过设置
torch.backends.cudnn.deterministic = True
来使操作确定性(可能以牺牲性能为代价)。有关更多信息,请参阅 可复现性。注意
此操作在使用 CUDA 设备上的张量时可能会产生非确定性梯度。有关更多信息,请参阅 可复现性。
- 参数
log_probs (Tensor) – or ,其中 C = 字母表中字符的数量(包括空格),T = 输入长度,N = 批次大小。输出的对数概率(例如,通过
torch.nn.functional.log_softmax()
获得)。targets (Tensor) – 或 (sum(target_lengths))。如果 target_lengths 中的所有条目都为零,则可能是一个空张量。在第二种形式中,目标被假定为连接的。
input_lengths (Tensor) – 或 。输入的长度(每个长度必须 )
target_lengths (Tensor) – 或 。目标的长度
blank (int, optional) – 空白标签。默认为 。
reduction (str, optional) – 指定应用于输出的归约方式:
'none'
|'mean'
|'sum'
。'none'
:不应用任何归约,'mean'
:输出损失将除以目标长度,然后计算批次上的平均值,'sum'
:将对输出求和。默认为'mean'
zero_infinity (bool, optional) – 是否将无限损失及其相关的梯度归零。默认为
False
。当输入太短而无法与目标对齐时,主要会出现无限损失。
- 返回类型
示例
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) >>> input_lengths = torch.full((16,), 50, dtype=torch.long) >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss.backward()