PyTorch 的 nn.NLLLoss
:负对数似然损失全解析
在 PyTorch 的损失函数家族中,nn.NLLLoss
(Negative Log Likelihood Loss,负对数似然损失)是一个不太起眼但非常重要的成员。它经常跟 LogSoftmax
搭配出现,尤其在分类任务中扮演关键角色。今天我们就来聊聊 nn.NLLLoss
的数学原理、使用方法,以及它适用的场景,带你彻底搞懂这个损失函数。
1. 什么是负对数似然损失?
先从名字拆解:
- 似然(Likelihood):在统计学中,似然表示“给定模型参数时,观察到数据的概率”。对数似然(Log Likelihood)是它的对数形式,常用于简化计算。
- 负对数似然(Negative Log Likelihood, NLL):把对数似然取负数,作为损失函数,目标是最小化它。
在机器学习中,负对数似然损失通常用来衡量模型预测的概率分布与真实标签的差距,尤其是在分类任务中。
数学公式
假设我们有一个多分类任务,有 ( C C C ) 个类别。对于一个样本:
- ( y ^ \hat{y} y^ ) 是模型输出的概率分布,比如经过 Softmax 或 LogSoftmax 处理后的结果。
- ( y y y ) 是真实类别,用索引表示(比如 2 表示第 2 类)。
nn.NLLLoss
的公式是:
NLL = − 1 N ∑ i = 1 N log ( y ^ i , y i ) \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log(\hat{y}_{i, y_i}) NLL=−N1i=1∑Nlog(y^i,yi)
- ( N N N ):样本数量(batch size)。
- ( y i y_i yi ):第 ( i i i ) 个样本的真实类别索引。
- ( y ^ i , y i \hat{y}_{i, y_i} y^i,yi ):第 ( i i i ) 个样本在真实类别 ( y i y_i yi ) 上的预测概率(对数值)。
简单来说,nn.NLLLoss
取预测概率的对数(已经由 LogSoftmax
计算好),然后取负号,只关心正确类别的概率值。
2. 为什么搭配 LogSoftmax
?
你可能会注意到,nn.NLLLoss
的文档里总是提到“通常与 LogSoftmax
搭配使用”。这是为什么?
- 模型输出:神经网络的最后一层通常输出未归一化的 logits(比如
[1.0, 2.0, 0.5]
),而不是概率。 - Softmax:将 logits 转为概率分布,比如
[0.2, 0.5, 0.3]
,满足 ( ∑ y ^ = 1 \sum \hat{y} = 1 ∑y^=1)。公式是:
y ^ j = e z j ∑ k = 1 C e z k \hat{y}_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} y^j=∑k=1Cezkezj - LogSoftmax:在 Softmax 基础上取对数,输出的是对数概率,比如
[-1.6, -0.7, -1.2]
。公式是:
log ( y ^ j ) = z j − log ( ∑ k = 1 C e z k ) \log(\hat{y}_j) = z_j - \log(\sum_{k=1}^{C} e^{z_k}) log(y^j)=zj−log(k=1∑Cezk)
nn.NLLLoss
要求输入是对数概率(log probabilities),而不是原始概率。所以:
- 如果你直接给它 Softmax 后的概率,会出错,因为它期待的是 ( log ( y ^ ) \log(\hat{y}) log(y^))。
- 用
LogSoftmax
处理后,输入正好符合要求,计算时直接取负号即可。
3. 代码使用示例
我们来看一个简单的例子,展示 nn.NLLLoss
和 LogSoftmax
的搭配:
import torch
import torch.nn as nn# 假设一个 3 分类任务,batch_size = 2
logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]]) # 原始 logits
target = torch.tensor([1, 2]) # 真实类别索引,0~2# 定义 LogSoftmax 和 NLLLoss
log_softmax = nn.LogSoftmax(dim=1) # dim=1 表示在类别维度上归一化
loss_fn = nn.NLLLoss()# 计算损失
log_probs = log_softmax(logits) # 先转为对数概率
loss = loss_fn(log_probs, target)
print("NLL Loss:", loss.item())
运行过程:
logits
是[batch_size, num_classes]
的张量,表示每个样本在每个类别上的得分。nn.LogSoftmax
把 logits 转为对数概率,比如[[-1.9, -0.9, -2.4], [-2.3, -1.9, -0.4]]
。nn.NLLLoss
提取每个样本在真实类别上的对数概率(比如第一个样本取-0.9
,第二个取-0.4
),取负并平均。
输出可能是 1.15
,具体值取决于输入。
4. 与 nn.CrossEntropyLoss
的关系
你可能听说过 nn.CrossEntropyLoss
,它也很常见。实际上:
nn.CrossEntropyLoss
=LogSoftmax
+nn.NLLLoss
PyTorch 把这两步合二为一,直接接受 logits 作为输入,内部自动完成 LogSoftmax 和 NLL 计算。具体过程可以参考笔者的另一篇博客:Pytorch为什么 nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss?
代码对比:
# 用 nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item()) # 与上面结果相同
- 区别:
nn.NLLLoss
:输入是对数概率,需手动加LogSoftmax
。nn.CrossEntropyLoss
:输入是 logits,自动处理。
5. 使用场景
nn.NLLLoss
适用于以下场景:
- 多分类任务:比如图像分类(CIFAR-10 的 10 类)、文本分类。
- 需要分离 Softmax 的情况:
- 你想在模型里显式控制 LogSoftmax 的位置,而不是交给损失函数。
- 调试时单独检查对数概率的值。
- 概率输出的模型:如果你的模型已经输出对数概率(比如某些预训练模型),直接用
nn.NLLLoss
更高效。
典型例子:
- 一个简单的 CNN 分类器:
这里模型输出对数概率,搭配class SimpleCNN(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(1, 16, 3)self.fc = nn.Linear(16 * 26 * 26, 10) # 假设 28x28 输入self.log_softmax = nn.LogSoftmax(dim=1)def forward(self, x):x = self.conv(x)x = x.view(x.size(0), -1)x = self.fc(x)return self.log_softmax(x)model = SimpleCNN() loss_fn = nn.NLLLoss()
nn.NLLLoss
计算损失。
6. 注意事项
- 输入形状:
- 输入:
[batch_size, num_classes]
(对数概率)。 - 目标:
[batch_size]
(类别索引)。
- 输入:
- 目标类型:必须是整数(long 类型),不能是 one-hot 或浮点数。
- 数值稳定性:
LogSoftmax
比单独的Softmax + log
更稳定,因为它避免了溢出问题。
7. 小结:nn.NLLLoss
的核心
- 数学原理:计算正确类别对数概率的负值,最小化它等价于最大化似然。
- 使用方式:搭配
LogSoftmax
,输入对数概率,输出标量损失。 - 场景:多分类任务,尤其是需要显式控制概率计算时。
- 与
CrossEntropyLoss
的关系:前者是后者的组成部分,功能更模块化。
nn.NLLLoss
就像一个“半成品”,需要你自己搭配 LogSoftmax
,但这也给了你更多灵活性。相比直接用 nn.CrossEntropyLoss
,它更适合喜欢拆解步骤或调试模型的开发者。
8. 调试小技巧
- 检查输入:打印
log_probs
确保是对数概率(负值)。 - 验证目标:确保
target
是整数,且范围在[0, num_classes-1]
。 - 对比结果:用
nn.CrossEntropyLoss
验证是否一致。
希望这篇博客让你对 nn.NLLLoss
有了全面认识!
后记
2025年2月28日18点59分于上海,在Grok3大模型辅助下完成。