欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 资讯 > LLMs-NTP损失函数推导

LLMs-NTP损失函数推导

2024/12/31 1:46:47 来源:https://blog.csdn.net/m0_46252199/article/details/144410752  浏览:    关键词:LLMs-NTP损失函数推导

next token prediction 是怎么计算损失的呢?

最近重看LLMs, 发现看损失函数又卡了一会儿,故记录一下。

【注意,下文描述的词,都是token的意思,方便表述】

推导过程

1 .优化目标

想一想我们生成的目的,那肯定是希望,尽可能生成我们预期的下一个词,
这里预期的下一个词是什么呢?不就是原来的文本吗?

我们这里使用文本来进行自监督学习,所以是知道实际值的

模型输出是什么呢?是一个logits(取softmax之后就是概率分布),所以可以说里面有每个词的概率。

我们取的是什么呢?概率最大的词,也就是argmax(logits)。
为了能取到这个词,我们就要让这个词的概率最大。
好了,我们目前优化的目标就有了。

2 .Loss设计

具体的Loss怎么设计,才能使得朝着我们想要的方向优化?

首先,目前的结果是一些概率。

为什么是一些,这里有n(上下文长度)个输入token,模型的输出也是n个token:[a,a+n-1]-> [a+1,a+n])
这里很容易混淆,认为一次推理,只会预测后一个token,实际上不是,前面的 n-1个token也会预测。

我们对这些概率取log (在数学优化过程中,处理概率得分的对数比直接处理得分本身更为便捷)

  1. 通过对概率值取对数,可以将乘法转换为加法,从而避免下溢问题
  2. 对数函数的梯度变化更加平滑,有助于优化算法的稳定性

ok,由于目前有n个数,我们可以取一个平均值,这样就只有一个值了。

为什么取平均?
它避免了序列长度、批次大小等因素对损失值的干扰,有助于更公平地评估模型性能,并使得模型训练更加一致和稳定。

注意到,之前的概率我们希望是1,取log之后,我们希望是0(log 1 = 0),而且目标是从负数增大0。(想象下log函数图像就理解了,x∈(0,1),logx∈(-∞,1))

这里将loss从负数优化到0,有点奇怪。

因为在深度学习中,常见的做法并不是提升平均对数概率至 0,而是降低负平均对数概率至 0。
负平均对数概率:平均对数概率乘以 -1

所以我们最后再乘以-1即可。

这里模型输出,是把词表中每个词的概率都输出了,类似于多分类任务,我们可以使用交叉熵损失函数。

图示

下图是完整的流程:
1->2: 取了softmax之后,变成了概率。
3: 是我们要预测的每一个token的现在的概率(模型此时的输出)
在这里插入图片描述

代码

logits 的维度通常是 [batch_size * seq_len, vocab_size]

  • batch_size * seq_len:这是所有序列中所有 token 的数量,也就是将输入序列展平后的 token 数量。
  • vocab_size:每个 token 对应的类别数,即词汇表的大小。

targets_flat 的维度是 [batch_size * seq_len]

  • 每个元素是一个整数,表示真实标签对应的类别索引(词汇表中的索引)
loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)

请注意,目标(targets)是标记ID,它们也代表了我们希望在logits张量中最大化的索引位置。
PyTorch中的cross_entropy函数会自动地将softmax和对数概率计算应用到这些要最大化标记索引的logits上。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com