欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 维修 > 【llm对话系统】大模型源码分析之llama kv cache缓存逻辑

【llm对话系统】大模型源码分析之llama kv cache缓存逻辑

2025/2/7 6:58:24 来源:https://blog.csdn.net/kakaZhui/article/details/145393435  浏览:    关键词:【llm对话系统】大模型源码分析之llama kv cache缓存逻辑

在大型语言模型(LLM)的推理过程中,为了提高生成速度,通常会采用一种名为 KV Cache 的优化技术。KV Cache 可以缓存中间计算结果,避免重复计算,从而显著提升推理效率。本文将深入探讨 LLaMA 模型中 KV Cache 的实现逻辑,包括训练和推理阶段的具体操作。

1. 什么是 KV Cache

1.1 为什么需要 KV Cache

在自回归模型中,模型会逐词生成文本。每生成一个新词,都需要进行一次完整的 Transformer 前向计算,包括计算注意力权重。然而,对于已经生成的词,它们的 Key (K) 和 Value (V) 向量在计算注意力时会被重复使用。

KV Cache 的作用就是缓存已经计算过的 Key 和 Value 向量,当生成新的词时,只需要计算新的 Query 向量,并使用缓存的 K 和 V 向量进行注意力计算,从而避免了对整个序列的重复计算。

1.2 KV Cache 的工作原理

具体来说,KV Cache 会缓存每一层 Transformer 编码器或解码器中的 K 和 V 向量。在推理过程中:

  1. 首次计算: 模型会计算完整输入序列的 K、V 和 Q 向量。
  2. 缓存 K 和 V: K 和 V 向量被保存到 KV Cache 中。
  3. 后续计算: 当模型需要生成下一个词时,只需要计算新生成词对应的 Q 向量。然后,将新 Q 向量与 KV Cache 中缓存的 K 和 V 向量进行注意力计算。

2. LLaMA 中的 KV Cache

LLaMA 模型在推理阶段使用了 KV Cache 来加速文本生成。

2.1 LLaMA 的实现逻辑

LLaMA 的 KV Cache 实现主要包括以下步骤:

  1. 初始化: 在推理开始时,KV Cache 被初始化为空。
  2. 计算 K 和 V: 当模型接收输入序列时,会计算每一层的 K 和 V 向量。
  3. 缓存 K 和 V: 将计算得到的 K 和 V 向量保存到 KV Cache 中。
  4. 更新 KV Cache: 当生成新的词时,只需要计算新词的 K 和 V 向量,并追加到 KV Cache 中。
  5. 注意力计算: 在进行注意力计算时,从 KV Cache 中取出 K 和 V 向量,与新的 Q 向量进行计算。

2.2 LLaMA 源码示例 (PyTorch)

以下是 LLaMA 模型中 KV Cache 的核心代码(简化版):

import torch
import torch.nn as nn
import mathclass LlamaAttention(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.head_dim = d_model // num_headsself.Wq = nn.Linear(d_model, d_model)self.Wk = nn.Linear(d_model, d_model)self.Wv = nn.Linear(d_model, d_model)self.Wo = nn.Linear(d_model, d_model)def forward(self, x, kv_cache=None, use_cache=False):batch_size, seq_len, _ = x.size()q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)k = self.Wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)v = self.Wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)if kv_cache is not None: # 使用KV Cacheprev_k, prev_v = kv_cachek = torch.cat([prev_k, k], dim=2)v = torch.cat([prev_v, v], dim=2)if use_cache: # 如果需要缓存current_kv_cache = (k,v)else:current_kv_cache = Nonescores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)attn_weights = torch.softmax(scores, dim=-1)attn_output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.Wo(attn_output)return output, current_kv_cache# 示例
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 5
attention_layer = LlamaAttention(d_model, num_heads)
input_tensor = torch.randn(batch_size, seq_len, d_model)
kv_cache = None
for i in range(10): # 模拟生成10个词output, kv_cache  = attention_layer(input_tensor, kv_cache=kv_cache, use_cache=True)print(f"Generated word {i+1}, Output shape: {output.shape}")input_tensor = torch.randn(batch_size, 1, d_model) # 假设每次生成一个词

代码解释:

  1. LlamaAttention 类:
    • forward 方法中,接受 kv_cacheuse_cache 参数。
    • 如果 kv_cache 不为空,则从缓存中获取之前的 kv,并与当前输入计算出的 kv 进行拼接。
    • 如果 use_cacheTrue,则返回当前的 kv,供后续使用。
    • 其余计算和没有KV Cache时一样。
  2. 示例:
    • 模拟生成10个词的过程,输入tensor每次的长度都为1,代表每次生成一个词。
    • 每次循环会更新kv_cache的值,并使用上次的kv_cache。

2.3 KV Cache 的更新逻辑

在推理阶段,KV Cache 需要随着新词的生成而不断更新。当生成一个新词时:

  1. 计算新的 K 和 V: 使用新词的输入向量计算对应的 K 和 V 向量。
  2. 追加到 Cache: 将新的 K 和 V 向量追加到 KV Cache 中。
  3. 后续计算: 在后续生成词的过程中,会使用更新后的 KV Cache。

3. 训练阶段的 KV Cache

在训练阶段,通常不需要使用 KV Cache,因为训练时会一次性输入整个序列。但是,为了保持训练和推理的一致性,有些实现可能会在训练阶段也计算并保存 KV Cache,但通常不会在注意力计算中使用它。

4. KV Cache 的优势

使用 KV Cache 能够带来以下优势:

  1. 推理加速: 避免重复计算已经生成的词的 K 和 V 向量,从而显著提高推理速度。
  2. 内存占用优化: 虽然 KV Cache 会占用一部分内存,但相对于重复计算而言,总体的内存占用是可控的。
  3. 资源利用率提高: 减少了模型计算量,提高了硬件资源利用率。

5. 总结

本文详细介绍了 LLaMA 模型中 KV Cache 的实现原理和使用方法。通过源码分析,我们了解了 KV Cache 在推理加速中的作用。希望本文能帮助你更好地理解大型语言模型中的优化技术。

6. 参考资料

  • Transformer Language Models
  • Efficiently Scaling Transformer Inference

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com