摘要
长上下文建模对于下一代语言模型至关重要,但标准注意力机制的高计算成本带来了巨大的计算挑战。稀疏注意力提供了一种在保持模型能力的同时提高效率的有前途的方向。本文提出了一种名为 NSA(原生可训练稀疏注意力机制) 的方法,该方法将算法创新与硬件对齐优化相结合,实现了高效的上下文建模。NSA 采用动态分层稀疏策略,结合粗粒度的标记压缩和细粒度的标记选择,在保留全局上下文感知和局部精度的同时,实现了以下两项关键创新:
- 算术强度平衡的算法设计,并针对现代硬件进行了实现优化,从而实现了大幅度的加速。
- 端到端训练支持,在不影响模型性能的情况下减少了预训练计算量。
如图 1 所示,实验表明,使用 NSA 预训练的模型在通用基准测试、长上下文任务和基于指令的推理方面均优于或与全注意力模型持平。同时,在处理 64k 长度的序列时,NSA 在解码、前向传播和后向传播的各个阶段均实现了相对于全注意力的显著加速,验证了其在模型整个生命周期中的效率。
1. 引言
研究界越来越认识到,长上下文建模是下一代大型语言模型的一项关键能力,这推动了从深度推理、代码库级代码生成到多轮自主代理系统等广泛现实世界应用的需求。最近的突破,包括 OpenAI 的 o 系列模型、DeepSeek-R1和 Gemini 1.5 Pro,使得模型能够处理整个代码库、长文档、在数千个标记上保持连贯的多轮对话,并在长程依赖关系上执行复杂推理。然而,随着序列长度的增加,vanilla Attention机制的高复杂度成为关键的延迟瓶颈。理论估计表明,在解码 64k 长度的上下文时,基于 softmax 的注意力计算占用了总延迟的 70-80%,这凸显了对更高效注意力机制的迫切需求。
高效长上下文建模的自然方法
利用 softmax 注意力固有的稀疏性是实现高效长上下文建模的一种自然方法,其中选择性地计算关键的查询-键对可以显著降低计算开销,同时保持性能。最近的进展通过多种策略证明了这一潜力:
- KV 缓存驱逐方法
- 分块 KV 缓存选择方法
- 基于采样、聚类或哈希的选择方法
尽管这些策略前景广阔,但现有的稀疏注意力方法在实际部署中往往不尽如人意。许多方法无法实现与理论收益相当的加速;此外,大多数方法主要侧重于推理阶段,缺乏有效的训练时支持来充分利用注意力的稀疏模式。
现有方法的局限性
为了解决这些局限性,有效部署稀疏注意力必须解决两个关键挑战:
- 硬件对齐的推理加速:将理论计算减少转化为实际速度提升,需要在预填充和解码阶段进行硬件友好的算法设计,以缓解内存访问和硬件调度瓶颈。
- 训练感知算法设计:支持端到端计算,并使用可训练的操作符来减少训练成本,同时保持模型性能。这些要求对于现实世界的应用实现快速长上下文推理或训练至关重要。在考虑这两个方面时,现有方法仍然存在明显差距。
本文贡献:NSA——原生可训练稀疏注意力架构
为了实现更有效和高效的稀疏注意力,本文提出了一种名为 NSA 的原生稀疏注意力架构,该架构集成了分层标记建模。
NSA 的核心创新:
- 硬件对齐的系统:针对张量核心利用率和内存访问优化了分块稀疏注意力,确保算术强度平衡。
- 训练感知设计:通过高效的算法和后向操作符实现了稳定的端到端训练。
这种优化使 NSA 能够支持高效的部署和端到端的训练。
2. 重新思考稀疏注意力方法
现代稀疏注意力方法在减少 transformer 模型理论计算复杂度方面取得了重大进展。然而,大多数方法主要在推理阶段应用稀疏性,同时保留预训练的全注意力主干,这可能会引入架构偏差,限制其充分利用稀疏注意力的优势。在介绍我们的原生稀疏架构之前,我们将通过两个关键视角系统地分析这些局限性。
2.1 推理效率的错觉
尽管在注意力计算中实现了稀疏性,但许多方法未能实现推理延迟的相应减少,主要原因有两个:
- 阶段受限的稀疏性:例如,H2O等方法在自回归解码期间应用稀疏性,但在预填充期间需要计算密集型的预处理(例如注意力图计算、索引构建)。相比之下,MInference等方法仅侧重于预填充稀疏性。这些方法无法在所有推理阶段实现加速,因为至少有一个阶段的计算成本与全注意力相当。这种阶段专业化降低了这些方法在以预填充为主的工作负载(如书籍摘要和代码完成)或以解码为主的工作负载(如长链式推理)中的加速能力。
- 与高级注意力架构不兼容:一些稀疏注意力方法无法适应现代解码高效架构,如多查询注意力 (MQA)和分组查询注意力 (GQA),它们通过在多个查询头之间共享 KV 显著减少了解码期间的内存访问瓶颈。例如,在 Quest等方法中,每个注意力头独立选择其 KV 缓存子集。虽然它在多头注意力 (MHA) 模型中表现出了一致的计算稀疏性和内存访问稀疏性,但它在基于 GQA 等架构的模型中呈现了不同的场景,其中 KV 缓存的内存访问量对应于同一 GQA 组内所有查询头的选择并集。这种架构特征意味着,虽然这些方法可以减少计算操作,但所需的 KV 缓存内存访问量仍然相对较高。这种局限性迫使一个关键的选择:虽然一些稀疏注意力方法减少了计算操作,但它们分散的内存访问模式与高级架构的高效内存访问设计相冲突。
这些局限性之所以出现,是因为许多现有的稀疏注意力方法侧重于 KV 缓存减少或理论计算减少,但难以在高级框架或后端实现显著的延迟减少。这促使我们开发结合了高级架构和硬件高效实现的算法,以充分利用稀疏性来提高模型效率。
2.2 可训练稀疏性的神话
我们对原生可训练稀疏注意力的追求源于对仅推理方法的两个关键洞察:
- 性能下降:事后应用稀疏性会迫使模型偏离其预训练优化轨迹。正如 Chen 等人所证明的,top 20% 的注意力只能覆盖 70% 的总注意力分数,这使得预训练模型中的检索头在推理期间容易受到剪枝。
- 训练效率需求:高效处理长序列训练对于现代 LLM 开发至关重要。这包括在更长的文档上进行预训练以增强模型容量,以及后续的适应阶段,如长上下文微调和强化学习。然而,现有的稀疏注意力方法主要针对推理,而将训练中的计算挑战基本未解决。这种局限性阻碍了通过高效训练开发更强大的长上下文模型。此外,将现有的稀疏注意力适应训练的努力也暴露了一些挑战:
- 不可训练组件:ClusterKV(包括 k-means 聚类)和 MagicPIG(包括基于 SimHash 的选择)中的离散操作会在计算图中产生不连续性。这些不可训练组件阻止了梯度通过标记选择过程流动,限制了模型学习最佳稀疏模式的能力。
- 后向传播效率低下:一些理论上可训练的稀疏注意力方法在实际训练中存在效率低下的问题。HashAttention等方法中使用的标记粒度选择策略导致在注意力计算期间需要从 KV 缓存中加载大量单个标记。这种非连续内存访问阻止了像 FlashAttention 这样的快速注意力技术的有效适应,这些技术依赖于连续内存访问和分块计算来实现高吞吐量。因此,实现被迫退回到低硬件利用率,这显著降低了训练效率。
2.3 原生稀疏性作为当务之急
这些在推理效率和训练可行性方面的局限性促使我们从根本上重新设计稀疏注意力机制。我们提出了一种名为 NSA 的原生稀疏注意力框架,它解决了计算效率和训练要求。在以下章节中,我们将详细介绍 NSA 的算法设计和操作符实现。
3. 方法论
我们的技术方法涵盖算法设计和内核优化。在以下小节中,我们首先介绍我们方法的背景。然后我们介绍 NSA 的整体框架,接着是其关键算法组件。最后,我们详细介绍了我们针对硬件优化的内核设计,以最大限度地提高实际效率。
3.1 背景
注意力机制在语言建模中被广泛使用,其中每个查询标记 q.t 计算与所有前面的键 k.t 的相关性得分,以生成值 v.t 的加权和。形式上,对于长度为 t 的输入序列,注意力操作定义为:
O t = Attn ( q t , K t , V t ) O_t = \text{Attn}(q_t, K_t, V_t) Ot=Attn(qt,Kt,Vt)
其中 Attn 表示注意力函数:
Attn ( q t , K t , V t ) = ∑ i = 1 t α t , i V t , i 其中 α t , i = exp ( q t ⋅ k t , i / d k ) ∑ j = 1 t exp ( q t ⋅ k t , j / d k ) \text{Attn}(q_t, K_t, V_t) = \sum_{i=1}^{t} \alpha_{t,i} V_{t,i} \quad \text{其中} \quad \alpha_{t,i} = \frac{\exp(q_t \cdot k_{t,i} / \sqrt{d_k})}{\sum_{j=1}^{t} \exp(q_t \cdot k_{t,j} / \sqrt{d_k})} Attn(qt,Kt,Vt)=i=1∑tαt,iVt,i其中αt,i=∑j=1texp(qt⋅kt,j/dk)exp(qt⋅kt,i/dk)
这里, α t , i \alpha_{t,i} αt,i 表示 q t q_t qt 和 k t , i k_{t,i} kt,i 之间的注意力权重, d k d_k dk 是键的特征维度。随着序列长度的增加,注意力计算在总体计算成本中变得越来越占主导地位,这给长上下文处理带来了重大挑战。
算术强度是计算操作与内存访问的比率。它从根本上塑造了算法在硬件上的优化。每个 GPU 都有一个关键的算术强度,由其峰值计算能力和内存带宽决定,计算为这两个硬件限制的比率。对于计算任务,高于此关键阈值的算术强度成为计算受限(受限于 GPU FLOPS),低于它则成为内存受限(受限于内存带宽)。
具体来说,对于因果自注意力机制,在训练和预填充阶段,批处理矩阵乘法和注意力计算表现出高算术强度,使这些阶段在现代加速器上成为计算受限的。相比之下,自回归解码成为内存带宽受限的,因为它每向前传递生成一个标记,同时需要加载整个键值缓存,导致算术强度低。这导致不同的优化目标——在训练和预填充期间减少计算成本,而在解码期间减少内存访问。
3.2 整体框架
为了利用具有自然稀疏模式的注意力的潜力,我们建议用一组更紧凑、信息更丰富的表示键值对 K t , V t K_t, V_t Kt,Vt 替换方程 (1) 中的原始键值对 k t , v t k_t, v_t kt,vt,对于每个查询 q t q_t qt。具体来说,我们正式定义优化后的注意力输出如下:
K t = f k ( q t , K t , V t ) , V t = f v ( q t , K t , V t ) K_t = f_k(q_t, K_t, V_t), \quad V_t = f_v(q_t, K_t, V_t) Kt=fk(qt,Kt,Vt),Vt=fv(qt,Kt,Vt)
O t = Attn ( q t , K t , V t ) O_t = \text{Attn}(q_t, K_t, V_t) Ot=Attn(qt,Kt,Vt)
其中 K t , V t K_t, V_t Kt,Vt 是基于当前查询 q t q_t qt 和上下文记忆 K t , V t K_t, V_t Kt,Vt 动态构建的。我们可以设计各种映射策略来获得不同类别的 K t , V t K_t, V_t Kt,Vt,并将它们组合如下:
O t = ∑ c ∈ C g c ⋅ Attn ( q t , K t c , V t c ) O_t = \sum_{c \in C} g_c \cdot \text{Attn}(q_t, K_t^c, V_t^c) Ot=c∈C∑gc⋅Attn(qt,Ktc,Vtc)
如图 2 所示,NSA 有三种映射策略 C = { cmp, slc, win } C = \{ \text{cmp, slc, win} \} C={cmp, slc, win},分别代表键和值的压缩、选择和滑动窗口。 g c ∈ [ 0 , 1 ] g_c \in [0,1] gc∈[0,1] 是相应策略 c c c 的门控得分,通过 MLP 和 sigmoid 激活从输入特征中导出。设 N c N_c Nc 表示重新映射的键/值的总数:
N c = ∑ c ∈ C size [ K t c ] N_c = \sum_{c \in C} \text{size}[K_t^c] Nc=c∈C∑size[Ktc]
我们通过确保 N c < t N_c < t Nc<t 来保持高稀疏率。
3.3 算法设计
在本小节中,我们将介绍我们的重新映射策略 f k f_k fk 和 f v f_v fv 的设计:标记压缩、标记选择和滑动窗口。
3.3.1 标记压缩
通过将键或值的连续块聚合到块级表示中,我们获得了压缩键和值,这些键和值捕获了整个块的信息。形式上,压缩键表示定义为:
K t cmp = MLP cmp ( [ concat i = 1 t / l [ k t , i , PE intra ( i ) ] ] ) K_t^{\text{cmp}} = \text{MLP}_{\text{cmp}} \left( \left[ \text{concat}_{i=1}^{t/l} \left[ k_{t,i}, \text{PE}_{\text{intra}}(i) \right] \right] \right) Ktcmp=MLPcmp([concati=1t/l[kt,i,PEintra(i)]])
其中 l l l 是块长度, d d d 是相邻块之间的滑动步幅, MLP cmp \text{MLP}_{\text{cmp}} MLPcmp 是一个可学习的 MLP,它具有块内位置编码,用于将块中的键映射到一个压缩键。 R t cmp ∈ R d k × t / l R_t^{\text{cmp}} \in \mathbb{R}^{d_k \times t/l} Rtcmp∈Rdk×t/l 是由压缩键组成的张量。通常,我们采用 d < l d < l d<l 来减轻信息碎片化。对于压缩值表示 V t cmp V_t^{\text{cmp}} Vtcmp 也有类似的公式。压缩表示捕获了更粗粒度的、更高级别的语义信息,并减轻了注意力的计算负担。
3.3.2 标记选择
仅使用压缩键和值可能会丢失重要的细粒度信息,这促使我们选择性地保留单个键和值。下面我们描述了我们的高效标记选择机制,该机制以低计算开销识别并保留最相关的标记。
分块选择:我们的选择策略按空间连续的块处理键和值序列,这是由两个关键因素驱动的:硬件效率考虑和注意力得分固有的分布模式。分块选择在现代 GPU 上实现高效计算至关重要。这是因为现代 GPU 架构对于连续块访问的吞吐量明显高于基于随机索引的读取。此外,分块计算能够最佳地利用张量核心。这种架构特性已使分块内存访问和计算成为高性能注意力实现中的一个基本原则,正如 FlashAttention 的基于块的设计所例证的那样。分块选择遵循注意力得分固有的分布模式。先前的工作表明,注意力得分通常表现出空间连续性,表明相邻的键往往具有相似的重要性水平。我们在第 6.2 节的可视化中也显示了这种空间连续模式。
为了实现分块选择,我们首先将键、值序列划分为选择块。为了识别对注意力计算最重要的块,我们需要为每个块分配重要性得分。下面我们介绍计算这些块级重要性得分的方法。
重要性得分计算:计算块重要性得分可能会带来显著的 overhead。幸运的是,压缩标记的注意力计算产生了中间注意力得分,我们可以利用这些得分来诱导选择块重要性得分,公式如下:
P t cmp = Softmax ( q t ⋅ K t cmp ) P_t^{\text{cmp}} = \text{Softmax}(q_t \cdot K_t^{\text{cmp}}) Ptcmp=Softmax(qt⋅Ktcmp)
其中 P t cmp ∈ R t / l P_t^{\text{cmp}} \in \mathbb{R}^{t/l} Ptcmp∈Rt/l 是 q t q_t qt 和压缩键 K t cmp K_t^{\text{cmp}} Ktcmp 之间的注意力得分。设 l ′ l' l′ 表示选择块大小。当压缩块和选择块共享相同的分块方案,即 l ′ = l = d l' = l = d l′=l=d 时,我们可以直接获得选择块重要性得分 P t slc P_t^{\text{slc}} Ptslc,公式为 P t slc = P t cmp P_t^{\text{slc}} = P_t^{\text{cmp}} Ptslc=Ptcmp。对于分块方案不同的情况,我们根据它们的空间关系推导选择块的重要性得分。给定 d ∣ l d | l d∣l 和 d ∣ l ′ d | l' d∣l′,我们有:
P t slc [ m , n ] = ∑ i = 0 l ′ − 1 ∑ j = 0 l − 1 P t cmp [ i + m ⋅ l + j ] P_t^{\text{slc}}[m,n] = \sum_{i=0}^{l'-1} \sum_{j=0}^{l-1} P_t^{\text{cmp}}[i + m \cdot l + j] Ptslc[m,n]=i=0∑l′−1j=0∑l−1Ptcmp[i+m⋅l+j]
其中 [ ⋅ ] [\cdot] [⋅] 表示用于访问向量元素的索引运算符。对于采用 GQA 或 MQA 的模型,其中键值缓存在查询头之间共享,必须确保在这些头之间的一致块选择,以最小化解码期间的 KV 缓存加载。在一个组内头之间共享的重要性得分正式定义为:
P t slc ′ = ∑ h = 1 H P t slc ( h ) P_t^{\text{slc}'} = \sum_{h=1}^{H} P_t^{\text{slc}(h)} Ptslc′=h=1∑HPtslc(h)
其中上标 ( h ) (h) (h) 表示头索引, H H H 是每个组中的查询头数。这种聚合确保了同一组内头之间的一致块选择。
Top-n 块选择:在获得选择块重要性得分后,我们保留排名靠前的 n 个稀疏块中的标记,按块重要性得分排名,公式如下:
I t = { i ∣ rank ( P t slc ′ [ i ] ) ≤ n } I_t = \{ i \mid \text{rank}(P_t^{\text{slc}'}[i]) \leq n \} It={i∣rank(Ptslc′[i])≤n}
K t slc = concat ( [ K t slc ′ [ i ⋅ l ′ : ( i + 1 ) ⋅ l ′ ] ∣ i ∈ I t ] ) K_t^{\text{slc}} = \text{concat} \left( \left[ K_t^{\text{slc}'}[i \cdot l' : (i+1) \cdot l'] \mid i \in I_t \right] \right) Ktslc=concat([Ktslc′[i⋅l′:(i+1)⋅l′]∣i∈It])
其中 rank ( ⋅ ) \text{rank}(\cdot) rank(⋅) 表示按降序排列的排名位置,排名 = 1 对应于最高得分, I t I_t It 是所选块的索引集, concat \text{concat} concat 表示连接操作。 K t slc ∈ R d k × n ⋅ l ′ K_t^{\text{slc}} \in \mathbb{R}^{d_k \times n \cdot l'} Ktslc∈Rdk×n⋅l′ 是由压缩键组成的张量。对于细粒度值 V t slc V_t^{\text{slc}} Vtslc 也有类似的公式。然后选定的键和值参与与 q t q_t qt 的注意力计算,如方程 (5) 中所定义。
3.3.3 滑动窗口
在注意力机制中,局部模式通常适应得更快,并可能主导学习过程,这可能会阻止模型有效地从压缩和选择标记中学习。为了解决这个问题,我们引入了一个专门的滑动窗口分支,它明确处理局部上下文,允许其他分支(压缩和选择)专注于学习各自的特征,而不会被局部模式所短路。具体来说,我们在窗口 w 中维护最近的标记 K t win = k t − w : t , V t win = v t − w : t K_t^{\text{win}} = k_{t-w:t}, V_t^{\text{win}} = v_{t-w:t} Ktwin=kt−w:t,Vtwin=vt−w:t,并将不同信息源的注意力计算(压缩标记、选择标记、滑动窗口)隔离到单独的分支中。然后,这些分支输出通过一个学习的门控机制聚合。为了进一步防止注意力分支之间的捷径学习,并引入最小的 overhead,我们为三个分支提供独立的键和值。这种架构设计通过防止局部和长程模式识别之间的梯度干扰,实现了稳定的学习,同时引入最少的开销。
在获得所有三种类别的键和值( K t cmp , V t cmp K_t^{\text{cmp}}, V_t^{\text{cmp}} Ktcmp,Vtcmp; K t slc , V t slc K_t^{\text{slc}}, V_t^{\text{slc}} Ktslc,Vtslc;以及 K t win , V t win K_t^{\text{win}}, V_t^{\text{win}} Ktwin,Vtwin)后,我们按照方程 (5) 计算最终的注意力输出。结合上述的压缩、选择和滑动窗口机制,这就构成了 NSA 的完整算法框架。
3.4 内核设计
为了在训练和预填充期间实现 FlashAttention 级别的加速,我们在 Triton 上实现了硬件对齐的稀疏注意力内核。鉴于 MHA 内存密集型且对解码效率低下,我们专注于具有共享 KV 缓存的架构,如 GQA 和 MQA,遵循当前最先进的 LLM 的做法。虽然压缩和滑动窗口注意力计算与现有的 FlashAttention-2 内核完全兼容,但我们引入了针对稀疏选择注意力的专用内核设计。如果我们遵循 FlashAttention 的策略,将时间连续的查询块加载到 SRAM 中,则会导致内存访问效率低下,因为块内的查询可能需要不同的 KV 块。为了解决这个问题,我们的核心优化在于不同的查询分组策略:对于查询序列上的每个位置,我们将组内的所有查询头(它们共享相同的稀疏 KV 块)加载到 SRAM 中。图 3 说明了我们的前向传递实现。所提出的内核架构具有以下关键特征:
- 以组为中心的数据加载:对于每个内部循环,加载组中位置 t 的所有头的查询 Q ∈ R h , d k Q \in \mathbb{R}^{h,d_k} Q∈Rh,dk 及其共享的稀疏键/值块索引 I t I_t It。
- 共享 KV 获取:在内部循环中,按顺序将索引为 I t I_t It 的连续键/值块加载到 SRAM 中,作为 K ∈ R B k , d k , V ∈ R B k , d v K \in \mathbb{R}^{B_k,d_k}, V \in \mathbb{R}^{B_k,d_v} K∈RBk,dk,V∈RBk,dv,以最小化内存加载,其中 B k B_k Bk 是满足 B k ∣ l ′ B_k | l' Bk∣l′ 的内核块大小。
- 网格上的外部循环:由于内部循环长度(与选定的块计数 n 成正比)对于不同的查询块几乎相同,我们将查询/输出循环放在 Triton 的网格调度器中,以简化和优化内核。
4. 实验
我们从三个方面评估 NSA:(1) 通用基准性能,(2) 长上下文基准性能,以及 (3) 链式推理性能,与全注意力基线和最先进的稀疏注意力方法进行比较。我们将稀疏计算范型的效率分析推迟到第 5 节,在那里我们将提供有关训练和推理速度的详细讨论。
4.1 预训练设置
遵循最先进的 LLM 中的常见做法,我们的实验采用了一个结合了分组查询注意力 (GQA) 和专家混合 (MoE) 的主干,特征是总共 27B 个参数,其中 3B 个参数处于活动状态。该模型由 30 层组成,隐藏维度为 2560。对于 GQA,我们将组数设置为 4,总共有 64 个注意力头。对于每个头,查询、键和值的隐藏维度分别配置为 d q = d k = 192 d_q = d_k = 192 dq=dk=192 和 d v = 128 d_v = 128 dv=128。对于 MoE,我们利用 DeepSeekMoE结构,有 72 个路由专家和 2 个共享专家,并将 top-k 专家设置为 6。为了确保训练稳定性,第一层的 MoE 被替换为 SwiGLU 形式的 MLP。
4.2 基线方法
除了与全注意力进行比较外,我们还评估了几种最先进的推理阶段稀疏注意力方法:H2O、infLLM、Quest和 Exact-Top,它首先计算完整的注意力得分,并选择每个查询对应的 top-n 得分键,然后计算这些位置上的注意力。这些方法跨越了不同的稀疏注意力范式,包括 KV 缓存驱逐、查询感知选择和精确 top-n 稀疏选择。
对于通用评估,其中大多数样本的长度在稀疏注意力基线的局部上下文窗口内,这些方法实际上等同于全注意力。因此,在这种情况下,我们仅展示 NSA 与全注意力基线的比较结果。在长上下文评估中,我们进行与所有基线方法的比较,所有稀疏注意力方法的稀疏性设置为相同,以确保公平比较。对于链式推理评估,它需要长文本监督微调,我们仅将比较限制在全注意力,因为稀疏注意力基线不支持训练。
4.3 性能比较
通用评估:我们评估了预训练的 NSA 和全注意力基线,在一系列涵盖知识、推理和编码能力的基准测试上,包括 MMLU、MMLU-PRO、CMMLU、BBH、GSM8K、MATH、DROP、MBPP和 HumanEval。结果如表 1 所示。尽管其稀疏性,NSA 实现了优异的整体性能,在 9 个指标中有 7 个超过了所有基线,包括全注意力。这表明,尽管 NSA 可能无法在较短序列上充分利用其效率优势,但它表现出强大的性能。值得注意的是,NSA 在与推理相关的基准测试中表现出显著增长(DROP:+0.042,GSM8K:+0.034),这表明我们的预训练有助于模型开发专门的注意力机制。这种稀疏注意力预训练机制迫使模型专注于最重要的信息,可能会通过过滤掉来自不相关注意力路径的噪声来提高性能。在各种评估中的持续表现也验证了 NSA 作为通用架构的稳健性。
长上下文评估:如图 5 所示,NSA 在 64k 上下文 needle-in-a-haystack测试的所有位置上都实现了完美的检索准确率。这种性能源于我们的分层稀疏注意力设计,它结合了压缩标记以实现高效的全局上下文扫描,以及选择标记以实现精确的局部信息检索。粗粒度的压缩以低成本识别相关上下文块,而对选定标记的标记级注意力确保了关键细粒度信息的保留。这种设计使 NSA 能够保持全局意识和局部精度。
我们还在 LongBench上评估了 NSA 与最先进的稀疏注意力方法和全注意力基线。为了确保稀疏性一致,我们将所有稀疏注意力基线中每个查询激活的标记数设置为 2560 个标记,这对应于 NSA 在处理 32k 序列长度时平均激活的标记数。根据 StreamLLM的说法,这个标记预算包括前 128 个标记和 512 个局部标记。我们排除了 LongBench 中的某些子集,因为所有模型的得分都较低,可能无法提供有意义的比较。如表 2 所示,NSA 实现了最高的平均得分 0.469,超过了所有基线(+0.032 超过全注意力,+0.046 超过 Exact-Top)。这种改进源于两个关键创新:(1) 我们的原生稀疏注意力设计,它能够在预训练期间对稀疏模式进行端到端优化,促进了稀疏注意力模块与其他模型组件的同步适应;(2) 分层稀疏注意力机制实现了局部和全局信息处理的平衡。
链式推理评估:为了评估 NSA 与高级下游训练范式的兼容性,我们研究了其通过后训练获得链式数学推理能力的能力。鉴于强化学习对小规模模型的有限有效性,我们采用 DeepSeek-R1 的知识蒸馏,使用 10B 标记的 32k 长度的数学推理轨迹进行监督微调 (SFT)。这产生了两个可比较的模型:全注意力-R(全注意力基线)和 NSA-R(我们的稀疏变体)。我们在具有挑战性的美国数学邀请赛 (AIME 24) 基准测试中评估了两个模型。我们使用采样温度为 0.7 和 top-p 值为 0.95 来生成每个问题的 16 个响应,并获得平均得分。为了验证推理深度的的影响,我们进行了两个生成上下文限制的实验:8192 和 16384 个标记,测量扩展推理链是否提高了准确性。附录 A 提供了模型预测的示例比较。
如表 3 所示,NSA-R 在 8192 上下文设置下实现了明显高于全注意力-R 的准确率(+0.075),并且这一优势在 16k 上下文设置下持续存在(+0.054)。这些结果验证了原生稀疏注意力的两个关键好处:(1) 预训练的稀疏注意力模式能够有效地捕获对复杂数学推导至关重要的长程逻辑依赖;(2) 我们架构的硬件对齐设计保持了足够的上下文密度,以支持不断增长的推理深度,而不会发生灾难性遗忘。不同上下文长度的持续表现优于证实了稀疏注意力在原生集成到训练管道中时对高级推理任务的可行性。
5. 效率分析
我们在 8-GPU A100 系统上评估了 NSA 与全注意力的计算效率。在效率分析中,我们还将模型配置为 GQA 组 g = 4 g=4 g=4,每组头数 h = 16 h=16 h=16,查询/键维度 d k = 192 d_k = 192 dk=192,以及值维度 d ν = 128 d_{\nu} = 128 dν=128。按照第 4 节中的相同设置,我们设置 NSA 压缩块大小 l = 32 l = 32 l=32,滑动步幅 d = 16 d = 16 d=16,选定块大小 l ′ = 64 l' = 64 l′=64,选定块计数 n = 16 n = 16 n=16,以及滑动窗口大小 w = 512 w = 512 w=512。
5.1 训练速度
我们将我们的 NSA 注意力与全注意力以及基于 Triton 的 FlashAttention-2 的 Triton 实现进行比较,以确保在相同的后端进行公平的速度比较。如图 6 所示,随着上下文长度的增加,我们的 NSA 实现了越来越多的加速,在 64k 上下文长度时,前向和后向加速高达 9.0 倍和 6.0 倍。值得注意的是,随着序列变长,速度优势变得更加明显。这种加速源于我们的硬件对齐的算法设计,以最大限度地提高稀疏注意力架构的效率:(1) 分块内存访问模式通过合并加载最大化张量核心利用率,(2) 内核中的精细循环调度消除了冗余 KV 传输。
5.2 解码速度
注意力解码速度主要由内存访问瓶颈决定,这与 KV 缓存加载量密切相关。在每个解码步骤中,我们的 NSA 只需要加载最多 ⌊ s − l d ⌋ \textstyle\left\lfloor{\frac{s-l}{d}}\right\rfloor ⌊ds−l⌋ 个压缩标记、 n l ′ n l' nl′ 个选定标记和 w w w 个邻居标记,其中 s \boldsymbol{s} s 是缓存的序列长度。如表 4 所示,随着解码长度的增加,我们的方法表现出显著的延迟减少,在 64k 上下文长度时达到高达 11.6 倍的加速。这种内存访问效率的优势也随着序列变长而放大。
6. 讨论
在本节中,我们反思了 NSA 的开发过程,并讨论了我们在探索不同稀疏注意力策略时获得的关键见解。虽然我们的方法显示出有希望的结果,但了解遇到的挑战并分析注意力模式为未来的研究方向提供了宝贵的背景。我们首先检查了替代标记选择策略的挑战,这些挑战激发了我们的设计选择,然后是可视化,这些可视化提供了对注意力分布模式的洞察。
6.1 替代标记选择策略的挑战
在设计 NSA 之前,我们探索了将现有的稀疏注意力方法适应训练阶段。然而,这些尝试遇到了各种挑战,促使我们设计了不同的稀疏注意力架构:
基于键聚类的策略:我们检查了像 ClusterKV这样的聚类策略。这些方法将来自同一聚类的键和值存储在连续的内存区域中。虽然理论上可行用于训练和推理,但它们面临三个重大挑战:(1) 动态聚类机制引入的非平凡计算 overhead;(2) 由于集群间的不平衡,加剧了操作符优化困难,尤其是在专家混合 (MoE) 系统中,其中专家并行性 (EP) 组执行时间的偏差导致持续的负载不平衡;(3) 由于需要强制性的定期重新聚类和块顺序训练协议而产生的实现约束。这些综合因素造成了巨大的瓶颈,严重限制了它们在现实世界部署中的有效性。
其他分块选择策略:我们也考虑了与 NSA 不同的分块键、值选择策略,例如 Quest和 InfLLM。这些方法依赖于计算每个块的重要性得分,并根据其与 q t q_t qt 的相似性选择 top-n 块。然而,现有方法面临两个关键问题:(1) 由于选择操作是不可微的,基于神经网络的 importance score 计算依赖于辅助损失,这增加了操作符 overhead,并经常导致模型性能下降;(2) 启发式无参数重要性得分计算策略由于低召回率,导致性能不佳。我们在一个具有类似架构的 3B 参数模型上评估了这两种方法,并将其损失曲线与 NSA 和全注意力进行比较。对于基于辅助损失的 selection 方法,我们引入额外的查询和每个块的代表性键来估计块重要性得分。这些得分由原始查询和每个块内的键之间的平均注意力得分监督。对于启发式无参数选择方法,按照 Quest 的策略,我们使用查询与键块坐标方向最小-最大值的乘积实现直接选择,而不引入额外参数。我们还探索了一种冷启动训练方法,即在过渡到启发式分块选择之前,先应用全注意力 1000 步。如图 7 所示,这两种方法都表现出较差的损失。
6.2 可视化
为了探索 transformer 注意力分布中的潜在模式,并为我们的设计寻找灵感,我们在图 8 中可视化了我们预训练的 27B 全注意力模型的注意力图。可视化揭示了有趣的模式,即注意力得分往往表现出分块聚类特征,相邻的键通常显示出相似的注意力得分。这一观察激发了 NSA 的设计,表明基于空间连续性选择键块可能是一种有前途的方法。分块聚类现象表明,序列中相邻的标记可能与查询标记共享某些语义关系,尽管这些关系的具体性质需要进一步研究。这一观察促使我们探索了一种稀疏注意力机制,它对连续标记块而不是单个标记进行操作,旨在提高计算效率并保留高注意力模式。
7. 相关工作
我们回顾了通过稀疏注意力提高注意力计算效率的现有方法。这些方法可以根据其核心策略大致分为三类:(1) 固定稀疏模式,(2) 动态标记剪枝,(3) 查询感知选择。我们介绍了每个类别中的几项代表性工作。
7.1 固定稀疏模式
滑动窗口是一种常用的方法,它允许查询仅在固定窗口内计算注意力。StreamingLLM通过维护上下文的关键部分:注意力接收器(早期标记)和局部上下文窗口,解决了处理长文本流时遇到的挑战。虽然这些方法有效地减少了内存和计算成本,但它们忽略上下文的刚性模式限制了其对需要完整上下文理解的任务的性能。
7.2 动态标记剪枝
H2O实现了一种自适应方法,在解码期间减少 KV 缓存内存使用。这种方法根据注意力得分动态地驱逐对未来的预测不太重要的标记。SnapKV还引入了标记剪枝策略,通过选择性地保留最关键的特性来减少 KV 缓存,从而实现高效的内存使用。SnapKV 通过在预填充期间进行注意力权重分析和投票来识别重要特性,然后通过将选定的压缩特性与最近上下文结合起来更新 KV 缓存,以保持提示一致性。
7.3 查询感知选择
Quest采用了一种分块选择策略,其中每个块的重要性是通过查询与键块坐标方向最小-最大值的乘积来估计的。结果得分有助于选择 top-重要的键值块进行注意力。InfLLM结合了固定模式与检索,通过维护注意力接收器、局部上下文和可检索块。这种方法从每个块中选择代表性键来估计块重要性。HashAttention通过使用学习函数将查询和键映射到 Hamming 空间,将关键标记识别公式化为推荐问题。ClusterKV通过首先聚类键,然后根据查询-聚类相似性选择最相关的集群来计算注意力,从而实现稀疏性。
8. 结论
我们提出了一种硬件对齐的稀疏注意力架构 NSA,用于高效的长上下文建模。通过在可训练架构中集成分层标记压缩和分块标记选择,我们的架构实现了加速训练和推理,同时保持了全注意力性能。NSA 通过展示通用基准性能与全注意力基线相匹配、在长上下文评估中超过建模能力以及增强的推理能力,所有这些都伴随着计算延迟的减少并实现了显著的速度提升,从而推进了最先进的技术。