next token prediction 是怎么计算损失的呢?
最近重看LLMs, 发现看损失函数又卡了一会儿,故记录一下。
【注意,下文描述的词,都是token的意思,方便表述】
推导过程
1 .优化目标
想一想我们生成的目的,那肯定是希望,尽可能生成我们预期的下一个词,
这里预期的下一个词是什么呢?不就是原来的文本吗?
我们这里使用文本来进行自监督学习,所以是知道实际值的
模型输出是什么呢?是一个logits(取softmax之后就是概率分布),所以可以说里面有每个词的概率。
我们取的是什么呢?概率最大的词,也就是argmax(logits)。
为了能取到这个词,我们就要让这个词的概率最大。
好了,我们目前优化的目标就有了。
2 .Loss设计
具体的Loss怎么设计,才能使得朝着我们想要的方向优化?
首先,目前的结果是一些概率。
为什么是一些,这里有n(上下文长度)个输入token,模型的输出也是n个token:[a,a+n-1]-> [a+1,a+n])
这里很容易混淆,认为一次推理,只会预测后一个token,实际上不是,前面的 n-1个token也会预测。
我们对这些概率取log (在数学优化过程中,处理概率得分的对数比直接处理得分本身更为便捷)
- 通过对概率值取对数,可以将乘法转换为加法,从而避免下溢问题
- 对数函数的梯度变化更加平滑,有助于优化算法的稳定性
ok,由于目前有n个数,我们可以取一个平均值,这样就只有一个值了。
为什么取平均?
它避免了序列长度、批次大小等因素对损失值的干扰,有助于更公平地评估模型性能,并使得模型训练更加一致和稳定。
注意到,之前的概率我们希望是1,取log之后,我们希望是0(log 1 = 0),而且目标是从负数增大0。(想象下log函数图像就理解了,x∈(0,1),logx∈(-∞,1))
这里将loss从负数优化到0,有点奇怪。
因为在深度学习中,常见的做法并不是提升平均对数概率至 0,而是降低负平均对数概率至 0。
负平均对数概率:平均对数概率乘以 -1
所以我们最后再乘以-1即可。
这里模型输出,是把词表中每个词的概率都输出了,类似于多分类任务,我们可以使用交叉熵损失函数。
图示
下图是完整的流程:
1->2: 取了softmax之后,变成了概率。
3: 是我们要预测的每一个token的现在的概率(模型此时的输出)
代码
logits 的维度通常是 [batch_size * seq_len, vocab_size]
- batch_size * seq_len:这是所有序列中所有 token 的数量,也就是将输入序列展平后的 token 数量。
- vocab_size:每个 token 对应的类别数,即词汇表的大小。
targets_flat 的维度是 [batch_size * seq_len]
- 每个元素是一个整数,表示真实标签对应的类别索引(词汇表中的索引)
loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)
请注意,目标(targets)是标记ID,它们也代表了我们希望在logits张量中最大化的索引位置。
PyTorch中的cross_entropy函数会自动地将softmax和对数概率计算应用到这些要最大化标记索引的logits上。