欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 国际 > 【论文解读】TransMLA: Multi-Head Latent Attention Is All You Need

【论文解读】TransMLA: Multi-Head Latent Attention Is All You Need

2025/2/24 11:56:07 来源:https://blog.csdn.net/qq_30731313/article/details/145816566  浏览:    关键词:【论文解读】TransMLA: Multi-Head Latent Attention Is All You Need

论文链接

1. 论文背景与问题动机

现代大规模语言模型(LLM)在推理时往往遇到通信瓶颈,主要原因在于自注意力机制中需要缓存大量的 Key-Value(KV)对。例如,对于 LLaMA‑65B 这种模型,即使采用 8 位量化,存储 512K 个 token 的 KV 缓存也需要超过 86GB 的 GPU 内存,这远远超出了单个高端 GPU 的容量。为了降低 KV 缓存带来的内存与通信开销,许多方法被提出,例如 Multi‑Query Attention (MQA) 和 Group Query Attention (GQA);然而,这些方法虽然降低了缓存需求,但通常会牺牲模型性能。论文正是在这种背景下,提出了一种新的注意力机制——多头潜变量注意力(Multi‑Head Latent Attention, MLA),旨在在不增加 KV 缓存开销的前提下提高模型表达能力。


2. 基本概念与传统注意力机制

在这里插入图片描述

2.1 Multi‑Head Attention (MHA)

  • 输入与权重变换
    给定输入序列 X ∈ R T × D X \in \mathbb{R}^{T \times D} XRT×D(T 为序列长度,D 为隐藏维度),通过三个权重矩阵 W Q W_Q WQ W K W_K WK W V W_V WV 将输入分别映射为 Query、Key 和 Value:
    Q = X W Q , K = X W K , V = X W V . Q = XW_Q,\quad K = XW_K,\quad V = XW_V. Q=XWQ,K=XWK,V=XWV.

  • 多头拆分
    Q Q Q K K K V V V 分成 n h n_h nh 个头,每个头的维度为 d h d_h dh
    Q = [ Q 1 ; Q 2 ; ⋯ ; Q n h ] , (其他同理) Q = \begin{bmatrix} Q_1; Q_2; \cdots; Q_{n_h} \end{bmatrix}, \quad \text{(其他同理)} Q=[Q1;Q2;;Qnh],(其他同理)

  • 注意力计算
    每个头计算注意力得分:
    O i = softmax ( Q i K i T d h ) V i W O i , O_i = \text{softmax}\Bigl(\frac{Q_iK_i^T}{\sqrt{d_h}}\Bigr)V_iW_{O_i}, Oi=softmax(dh QiKiT)ViWOi,
    最终将所有头的输出相加得到最终结果 O O O

2.2 Group Query Attention (GQA)

  • 思想
    GQA 的目标是降低 KV 缓存的开销,它将所有 Query 头分成若干组,每组共享同一个 Key 和 Value。设总共 n q n_q nq 个 Query 头,Key 与 Value 只用 n k n_k nk 个头(其中 n k < n q n_k < n_q nk<nq),并通过复制(replication)操作使得每个 Query 都能匹配到 Key。

  • 具体操作

    1. 将输入 X X X 分别映射为 Q ∈ R T × ( n q ⋅ d h ) Q\in\mathbb{R}^{T\times(n_q\cdot d_h)} QRT×(nqdh) K ∈ R T × ( n k ⋅ d h ) K\in\mathbb{R}^{T\times(n_k\cdot d_h)} KRT×(nkdh)(以及 V V V)。
    2. 为了使得 Query 与 Key 数量匹配,需要对 Key 进行复制:设复制因子为 s = n q n k s = \frac{n_q}{n_k} s=nknq,将每个 Key 头复制 s s s 次,拼接得到扩展后的 Key 矩阵 K ′ K' K
  • 特殊情况
    n k = n q n_k = n_q nk=nq 时,GQA 恢复为标准 MHA;当 n k = 1 n_k = 1 nk=1 时,即为 Multi‑Query Attention (MQA)。

2.3 Multi‑Head Latent Attention (MLA)

  • 核心思路
    MLA 提出用低秩矩阵因子分解来近似传统的复制操作。其主要思想是将 Key 层的投影矩阵进行因子分解,从而只需缓存一个低维的“潜变量”表示,再通过一个上投影矩阵恢复完整表示。

  • 具体实现
    设:

    • W Q ∈ R D × ( n h ⋅ d h ) W_Q \in \mathbb{R}^{D\times(n_h\cdot d_h)} WQRD×(nhdh) 用于生成 Query;
    • 对于 Key 和 Value 层,分别使用两个矩阵 W K a W^a_K WKa(或 W V a W^a_V WVa)和 W K b W^b_K WKb(或 W V b W^b_V WVb),其中 W K a ∈ R D × r W^a_K \in \mathbb{R}^{D\times r} WKaRD×r 将输入映射到一个低维表示( r r r 远小于 n h ⋅ d h n_h\cdot d_h nhdh),而 W K b ∈ R r × ( n h ⋅ d h ) W^b_K \in \mathbb{R}^{r\times(n_h\cdot d_h)} WKbRr×(nhdh) 负责扩充回原来的维度。

    具体计算为:
    Q = X W Q , K = X W K a W K b , V = X W V a W V b . Q = XW_Q, \quad K = XW^a_K W^b_K, \quad V = XW^a_V W^b_V. Q=XWQ,K=XWKaWKb,V=XWVaWVb.

    这样,在推理时只需要缓存 X W K a XW^a_K XWKa X W V a XW^a_V XWVa(低维表示),而不必保存完整的高维 Key 和 Value,从而大大降低了 KV 缓存的存储开销。


3. TransMLA:从 GQA 到 MLA 的转换

论文在理论上证明了以下定理

定理 1: 当 KV 缓存大小相同时,MLA 的表达能力严格大于 GQA 的表达能力。

为证明这一点,论文从三个方面展开讨论:

3.1 GQA 中的 Key 复制

  • 过程说明
    在 GQA 中,假设输入经过 W K W_K WK 得到 Key K K K(具有 n k n_k nk 个头,每个头维度为 d h d_h dh),为了匹配 n q n_q nq 个 Query 头,需要对每个 Key 头进行复制。具体地,将 K K K 按照列分成 n k n_k nk 个块,然后每个块复制 s = n q n k s = \frac{n_q}{n_k} s=nknq 次,拼接后形成扩展后的矩阵 K ′ K' K

  • 数学等价性
    这种方法与先计算后复制数学上是等价的,但为后续的低秩分解奠定了基础。

3.2 将复制操作转移到参数侧

  • 方法描述
    不必先计算 Key 后再复制,我们可以直接在参数矩阵 (W_K) 上进行复制。将 (W_K) 按照列分为若干个小矩阵 (W_K^{(i)}),然后对每个小矩阵复制 (s) 次,拼接得到新的矩阵 (W’_K),最终直接计算 (K’ = XW’_K)。

  • 数学等价性
    这种方法与先计算后复制数学上是等价的,但为后续的低秩分解奠定了基础。

3.3 MLA 的低秩分解形式

  • SVD 分解
    论文指出,由于 W K ′ W'_K WK 仅仅是 W K W_K WK 复制而成,其自由度最多只有 n k ⋅ d h n_k \cdot d_h nkdh 个。利用奇异值分解(SVD),可以将 W K ′ W'_K WK 分解为:
    W K ′ = U K S K V K ⊤ . W'_K = U_K S_K V_K^\top. WK=UKSKVK.
    由于只有最多 n k ⋅ d h n_k \cdot d_h nkdh 个非零奇异值,可以截断 SVD,仅保留前 r ≤ n k ⋅ d h r \leq n_k \cdot d_h rnkdh 个奇异值。

  • 构造因子
    定义:
    W K a = U K [ : , : r ] S K [ : r , : r ] , W K b = S K [ : r , : r ] V K [ : r , : ] ⊤ . W^a_K = U_K[:,:r] \sqrt{S_K[:r,:r]}, \quad W^b_K = \sqrt{S_K[:r,:r]} V_K[:r,:]^\top. WKa=UK[:,:r]SK[:r,:r] ,WKb=SK[:r,:r] VK[:r,:].
    则有:
    W K ′ = W K a W K b , K ′ = X W K ′ = X W K a W K b . W'_K = W^a_K W^b_K, \quad K' = XW'_K = XW^a_K W^b_K. WK=WKaWKb,K=XWK=XWKaWKb.

  • 启示
    这表明,原本 GQA 中通过复制得到的 Key,其实可以看作是一个低秩分解的结果。相比于直接复制,MLA 在引入额外上投影( W K b W^b_K WKb)之后,可以在保持相同 KV 缓存大小的前提下,增加额外的表达能力。

3.4 MLA 的表达能力优势

  • 不可表示性
    论文讨论了一个特殊情况:如果 W K b W^b_K WKb的各个向量正交,则通过 X W K a XW^a_K XWKa得到的低维表示,在经过 W K b W^b_K WKb 扩展后,各个通道输出依然保持各自的独立性。而在 GQA 中,同一组内所有头都是完全相同的(由于复制操作),因此无法捕捉这种多样性。
  • 结论
    因此,在相同 KV 缓存大小下,MLA 能够表达出更多种类的模式,从而具有更强的表达能力。

4. TransMLA 实践:将 GQA 模型转换为 MLA 模型

论文不仅在理论上证明了 MLA 的优势,还提出了一种名为 TransMLA 的方法,将已有的 GQA‑based 模型转换为 MLA 模型。转换过程中主要的步骤包括:

  1. 参数转换

    • 对于 Q-K 对,原来 GQA 模型中用于 Key 层的矩阵 W K W_K WK被分解为 W K a W^a_K WKa W K b W^b_K WKb
    • 在转换后, W K a W^a_K WKa W V a W^a_V WVa的输出维度调整为一个较小的值(例如 512),而 KV 缓存的尺寸保持不变(例如 1024)。
    • 为了使得转换后 Query 与 Key/Value 之间依然可以充分交互, W K b W^b_K WKb W V b W^b_V WVb被设计为将低维表示扩展到一个更高的维度(例如 28×128 = 3584)。
  2. 参数增加的代价

    • 额外参数主要来自于增加的 W K b W^b_K WKb W V b W^b_V WVb 矩阵,但论文指出这部分参数只占原始参数的 1/8,对于整体模型(例如从 7.6B 增加到 7.7B)来说,增长非常有限。
  3. 后续训练与微调

    • 转换完成后,论文对转换后的模型进行进一步的训练,以提升模型的表达能力,而不会增加 KV 缓存的尺寸。这种后训练方法(post‑training)使得模型在保持低延迟的同时,能够充分利用 MLA 带来的优势。

5. 实验验证

论文在实验部分主要展示了 TransMLA 模型在下游任务上的性能提升,并与原始 GQA 模型进行了对比。

5.1 实验设置

  • 模型选择
    使用 Qwen2.5 系列模型,其中 Qwen2.5‑7B 模型有 28 个 Query 头和 4 个 Key/Value 头,每个头维度为 128,对应的 KV 缓存尺寸为 1024;而 Qwen2.5‑14B 模型相应地有更多头数和更大的 KV 缓存尺寸(2048)。

  • 转换细节
    在转换为 MLA 模型后:

    • W K a W^a_K WKa W V a W^a_V WVa 的输出维度调整为 512,
    • W K b W^b_K WKb W V b W^b_V WVb 的维度调整为 3584,
    • 总体参数量仅略有增加(例如 7.6B 增至 7.7B)。

5.2 微调与性能对比

  • 微调数据集
    使用包含数学(例如 Meta‑MathQA)和编程任务(例如 Self‑OSS‑Starcoder2‑Instruct)的指令微调数据集 SmolTalk,对比 GQA 模型与 TransMLA 模型在训练过程中的表现。

  • 实验结果

    • 训练损失:如图 2a 所示,TransMLA 模型在训练过程中损失显著降低,说明其数据拟合能力更强。
    • 测试准确率:图 2b 展示了在 7B 与 14B 模型下,TransMLA 模型在数学和编程任务上均取得了更高的准确率。
  • 消融实验
    论文还对仅通过身份映射(identity map)初始化进行维度扩展的版本进行对比,发现这种方法仅带来微小提升(例如准确率提升仅约 0.15%),从而验证了正交分解(orthogonal decomposition)在提升模型表达能力方面的关键作用。


6. 结论与未来工作

  • 结论
    论文证明了在 KV 缓存大小相同的条件下,MLA 的表达能力严格大于 GQA。理论上通过将 GQA 的复制机制转化为低秩因子分解,可以实现同样的 KV 缓存开销,但同时获得更丰富的表示能力。实验结果进一步证明了转换后的 MLA 模型在下游任务上表现更优。

  • 未来工作
    作者计划将这一方法扩展到更大规模的模型(如 LLaMA、Qwen、Mistral),并利用 DeepSeek R1 蒸馏技术进一步优化转换模型的性能,同时开发针对 MLA 的专门推理加速策略,以实现更低延迟和更高效的资源利用。


总结

整篇论文的核心贡献在于:

  1. 理论证明:展示了如何将 GQA 中的复制操作转化为低秩分解,并证明了在同等 KV 缓存开销下,MLA 的表达能力更强,尤其是在允许不同头之间产生更多差异性表示方面。
  2. 实践方案(TransMLA):提出了一种后训练方法,将现有的 GQA 模型转换为 MLA 模型,只需极少的额外参数即可显著提升模型性能。
  3. 实验验证:通过对 Qwen2.5 系列模型的微调实验,证明了 TransMLA 模型在数学、编程等任务上均优于原始 GQA 模型,验证了理论上的优势。

通过这种方法,论文为未来设计更高效且表达能力更强的注意力机制提供了新的思路,同时也为降低大模型在长序列推理时的资源消耗提出了切实可行的解决方案。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词