欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 维修 > Transformers是SSMs:通过结构化状态空间对偶性的广义模型和高效算法(一)

Transformers是SSMs:通过结构化状态空间对偶性的广义模型和高效算法(一)

2024/11/30 14:44:55 来源:https://blog.csdn.net/m0_47867638/article/details/139637306  浏览:    关键词:Transformers是SSMs:通过结构化状态空间对偶性的广义模型和高效算法(一)

文章目录

  • 摘要
  • 1、引言
  • 2、背景与概述
    • 2.1、结构化状态空间模型
    • 2.2、注意力机制
    • 2.3、结构化矩阵
    • 2.4、概述:结构化状态空间对偶性
    • 2.5、符号
  • 3、状态空间模型是结构化矩阵
    • 3.1、状态空间模型的矩阵变换形式
    • 3.2、半可分离矩阵
      • 3.2.1、顺序半可分离(SSS)表示
      • 3.2.2、1-半可分矩阵:标量SSM递推
    • 3.3、状态空间模型是半可分矩阵
    • 3.4、通过结构化矩阵算法计算状态空间模型
      • 3.4.1、线性(递归)模式
      • 3.4.2、二次(朴素)模式
      • 3.4.3、总结
    • 4.1、注意力框架
      • 4.1.1、注意力
      • 4.1.2、自注意力
      • 4.1.3、核注意力
      • 4.1.4、掩码(核)注意力
    • 4.2、线性注意力
      • 4.2.1、线性注意力的张量收缩证明
    • 4.3、结构化掩码注意力
  • 5、状态空间对偶性
    • 5.1、标量-恒等结构化状态空间模型
    • 5.2、1-半可分结构化掩码注意力

摘要

链接:https://arxiv.org/pdf/2405.21060
尽管Transformer一直是深度学习在语言建模中取得成功的主要架构,但最近的研究表明,如Mamba之类的状态空间模型(SSMs)在小到中等规模上能够匹敌或超越Transformer的性能。我们表明,这两类模型实际上是非常相关的,并在一个经过充分研究的结构化半可分离矩阵类的各种分解之间,发展出SSM和注意力变体之间丰富的理论联系框架。我们的状态空间对偶性(SSD)框架使我们能够设计一种新的架构(Mamba-2),其核心层是对Mamba的选择性SSM的改进,速度提高了2-8倍,同时在语言建模方面继续与Transformer保持竞争力。

1、引言

Transformer,特别是仅解码器模型(例如GPT(Brown等人,2020年)、Llama(Touvron,Lavril等人,2023年)),这些模型以因果方式处理输入序列,是现代深度学习成功的主要驱动力之一。为了解决Transformer核心注意力层的效率问题(Tay等人,2022年),已经提出了许多方法来近似它,如训练期间序列长度的二次缩放,以及在自回归生成期间需要线性于序列长度的缓存。与此同时,一类替代的序列模型——结构化状态空间模型(SSMs)已经出现,它们在训练期间具有线性于序列长度的缩放,并在生成期间具有恒定的状态大小。这些模型在远程任务(例如S4(Gu,Goel和Ré,2022年))上表现出强大的性能,并且最近在小型到中等规模的语言建模任务上匹配或超越了Transformer(例如Mamba(Gu和Dao,2023年))。然而,SSM的发展似乎与社区为提高Transformer性能而付出的集体努力(如从理论上理解它们以及在现代硬件上优化它们)相脱节。因此,与Transformer相比,SSM更难理解和实验,从算法和系统的角度来看,训练SSM的效率仍然难以与Transformer相媲美。

我们的主要目标是建立结构化SSM和注意力变体之间丰富的理论联系。这将使我们能够将最初为Transformer开发的算法和系统优化转移到SSM上,以构建在序列长度上缩放效率更高且性能优于Transformer的基础模型为目标。这一方向的一个里程碑性贡献是线性注意力(LA)框架(Katharopoulos等人,2020年),它通过展示二次核化注意力的“对偶形式”与特定线性递归之间的等价性,得出了自回归注意力和线性RNN之间的联系。这种对偶性允许新的能力,如既能进行高效的并行训练又能进行高效的自回归推理。本着同样的精神,本文提供了多个视角,将线性复杂度的SSM与二次复杂度的形式联系起来,以结合SSM和注意力的优势。{ }^{1}

状态空间对偶性。我们将结构化SSM和注意力变体联系起来的框架,我们称之为结构化状态空间对偶性(SSD),是通过结构化矩阵的抽象来实现的:具有次二次参数和乘法复杂度的矩阵。我们开发了两个广泛的框架来表示序列模型,一个作为矩阵变换,另一个作为张量收缩,每个框架都揭示了对偶性的不同视角。我们的技术贡献包括:

  • 我们展示了状态空间模型与一种经过充分研究的结构化矩阵家族——半可分离矩阵之间的等价性(第3节)。这种联系是我们框架的核心,揭示了SSM的新属性和算法。本文的一个中心信息是,计算状态空间模型的不同方法可以重新表述为结构化矩阵上的各种矩阵乘法算法。

  • 我们显著改进了线性注意力(Katharopoulos等人,2020年)的理论。我们首先通过张量收缩的语言为其递归形式提供了一个深入的证明,然后将其推广到一个新的结构化掩码注意力(SMA)家族(第4节)。

  • 我们将SSM和SMA联系起来,展示了它们之间有一个很大的交集,它们互为对偶,同时具有类似SSM的线性形式和类似注意力的二次形式(第5节)。我们还证明了任何具有快速递归形式的核注意力方法都必须是SSM。

除了其固有的理论价值外,我们的框架为理解和改进序列模型开辟了一系列广泛的方向。

高效算法。首先且最重要的是,我们的框架为计算SSM(第6节)提供了新的高效且易于实现的算法。我们基于半可分矩阵的块分解,引入了一种新的SSD算法,该算法同时利用了线性SSM递归和二次对偶形式,在所有主要效率轴(例如,训练和推理计算、内存使用以及在现代硬件上利用矩阵乘法单元的能力)上获得了最佳权衡。SSD的专用实现比Mamba的优化选择性扫描实现快 2 − 8 2-8 28倍,同时允许使用更大的循环状态大小(是Mamba大小的 8 8 8倍或更高,且几乎没有减速)。SSD与优化的softmax注意力实现(FlashAttention-2(Dao 2024))高度竞争,在序列长度为 2 K 2 \mathrm{~K} 2 K时交叉,并在序列长度为 16 K 16 \mathrm{~K} 16 K时快 6 6 6倍。

架构设计。采用新架构(如SSM)的一个主要障碍是专为Transformer定制的生态系统,如针对大规模训练的硬件高效优化和并行技术。我们的框架允许使用已建立的注意力机制的传统和技术,为SSM构建架构设计的词汇表,并进一步改进它们(第7节)。例如,我们将多头注意力(MHA)中的头(head)的概念引入到SSM中。我们展示了Mamba架构是一个多输入SSM(MIS),这类似于多值注意力(MVA),并比较了具有不同头结构的Mamba的其他变体。

我们还利用这些想法对Mamba块进行了微小的修改,从而实现了张量并行性(例如,以Megatron(Shoeybi等人,2019年)的风格)。主要思想包括引入分组值注意力(GVA)头结构,并将所有依赖于数据的投影移至块的开始处并行执行。

通过结合修改后的并行Mamba块以及使用SSD作为内部SSM层,我们得到了Mamba-2架构。我们在与Mamba相同的设置下研究了Mamba-2的Chinchilla缩放定律,发现它在困惑度和时钟时间方面均优于Mamba和Transformer ++。此外,我们还在Pile上训练了不同大小的Mamba-2模型家族,表明它在标准下游评估中匹配或超越了Mamba和开源Transformer。例如,在Pile上训练了300B标记的具有2.7B参数的Mamba-2模型,其性能超过了在相同数据集上训练的Mamba-2.8B、Pythia-2.8B以及Pythia-6.9B。

系统优化。SSD框架将SSM和Transformer连接起来,使我们能够利用为Transformer开发的丰富系统优化工作(第8节)。

  • 例如,张量并行性(TP)是一种重要的模型并行技术,用于通过在同一节点上的GPU之间拆分每一层来训练大型Transformer模型。我们设计了Mamba-2以使其支持TP,从而将每个块的同步点数量减少了一半。
  • 对于激活不适合单个设备的非常长的序列,已经为注意力块开发了序列并行性。我们描述了如何通过在不同设备之间传递循环状态来训练SSM,特别是Mamba-2,以实现序列并行性。
  • 对于具有不同长度示例的微调,为了提高效率,Transformer需要复杂的技术来删除填充令牌并在可变长度序列上执行注意力。我们展示了Mamba-2如何能够高效地以可变序列长度进行训练,而无需填充令牌。

第9节通过语言建模、训练效率和困难的多查询关联召回任务(Arora, Eyuboglu, Zhang等人,2024年)对Mamba-2进行了实证验证。最后,在第10节中,我们提供了扩展的相关工作,并讨论了我们的框架开启的潜在研究方向。

模型代码和预训练检查点已在https://github.com/state-spaces/mamba上开源。

2、背景与概述

2.1、结构化状态空间模型

结构化状态空间序列模型(S4)是深度学习领域的一类新型序列模型,与RNNs、CNNs以及经典的状态空间模型有广泛的相关性。它们受到特定连续系统的启发,该系统通过一个隐含的潜状态 h ∈ R ( ⊤ , N ) h \in \mathbb{R}^{(\top, N)} hR(,N) 将一维序列 x ∈ R ⊤ x \in \mathbb{R}^{\top} xR 映射到 y ∈ R ⊤ y \in \mathbb{R}^{\top} yR

结构化SSM的一般离散形式可以表示为方程(1)。

在这里插入图片描述

其中 A ∈ R ( N , N ) A \in \mathbb{R}^{(N, N)} AR(N,N) B ∈ R ( N , 1 ) B \in \mathbb{R}^{(N, 1)} BR(N,1) C ∈ R ( N , 1 ) C \in \mathbb{R}^{(N, 1)} CR(N,1)。结构化SSM之所以得名,是因为控制时间动态的 A A A 矩阵必须结构化,以便这种序列到序列的转换能够足够高效地计算,从而可以在深度神经网络中使用。最初引入的结构包括对角加低秩(DPLR)(Gu, Goel, and Ré 2022)和对角结构(Gu, Gupta, et al. 2022; Gupta, Gu, and Berant 2022; J. T. Smith, Warrington, and Linderman 2023),其中对角结构仍然是最流行的结构。

在这项工作中,我们使用“状态空间模型”(SSM)这一术语来指代结构化SSM。这类SSM有许多变种,与神经序列模型的几个主要范式如连续时间、递归和卷积模型(Gu, Johnson, Goel等人,2021)有深厚的联系。我们下面简要概述一下,并参考以前的工作以获取更多上下文和细节(Gu 2023;Gu 和 Dao 2023)。

连续时间模型。最初的结构化SSM起源于连续时间映射函数 x ( t ) ∈ R ↦ y ( t ) ∈ R x(t) \in \mathbb{R} \mapsto y(t) \in \mathbb{R} x(t)Ry(t)R,而不是直接在序列上操作。在连续时间视角下,方程(1a)中的矩阵 ( A , B ) (A, B) (A,B)不是直接学习的,而是从底层参数 ( A ˙ , B ˙ ) (\dot{A}, \dot{B}) (A˙,B˙)以及参数化步长 Δ \Delta Δ生成的。通过固定公式 A = f A ( Δ , A ˚ ) A=f_{A}(\Delta, \AA) A=fA(Δ,A˚) B = f B ( Δ , B ) B=f_{B}(\Delta, B) B=fB(Δ,B),将“连续参数” ( Δ , A ˚ , A ˙ , B ˙ ) (\Delta, \AA, \dot{A}, \dot{B}) (Δ,A˚,A˙,B˙)转换为“离散参数” ( A , B ) (A, B) (A,B),其中对 ( f A , f B ) (f_{A}, f_{B}) (fA,fB)称为离散化规则。

备注1。虽然我们的主要模型采用了与之前工作相同的参数化和离散化步骤(详见Gu和Dao(2023)),但为了简化阐述和符号表示,我们在本文的其余部分省略了这一步骤。我们注意到,先前关于结构化SSM的工作将连续参数 ( A ˙ , B ˙ ) (\dot{A}, \dot{B}) (A˙,B˙)和离散参数 ( A , B ) (A, B) (A,B)分别称为 ( A , B ) (A, B) (A,B) ( A ˉ , B ˉ ) (\bar{A}, \bar{B}) (Aˉ,Bˉ);我们更改了符号以简化呈现,并直接关注于控制SSM主要递归的离散参数。

递归模型。方程(1)和(2)采用了在输入x上线性的递归形式。因此,结构化SSM可以看作是一种递归神经网络(RNN),其线性特性赋予它们额外的属性,并允许它们避免传统RNN的顺序计算。反之,尽管进行了这种简化,但SSM作为序列转换仍然具有完全的表达能力(在通用近似的意义上)(Kaul 2020;Orvieto等人,2023;Shida Wang和Xue,2023)。

卷积模型。当SSM的动力学特性如方程(1)所示在时间上是恒定时,该模型被称为线性时不变(LTI)。在这种情况下,它们等同于卷积。因此,SSM也可以看作是CNN的一种类型,但其中(i)卷积核是通过SSM参数(A,B,C)隐式参数化的,并且(ii)卷积核通常是全局的而非局部的。相反,通过经典的信号处理理论,所有表现良好的卷积都可以表示为SSM。

通常,以前的LTI SSM会使用卷积模式进行高效的并行化训练(在此模式下,可以提前看到整个输入序列),并切换到递归模式(1)以进行高效的自回归推理(在此模式下,每次只看到一个步骤的输入)。

选择性状态空间模型。在Mamba中引入的形式(2),其中参数 ( A , B , C ) (A, B, C) (A,B,C)也可以随时间变化,被称为选择性SSM。与标准的LTI形式(1)相比,该模型可以选择在每个时间步上关注或忽略输入。它在诸如语言这样的信息密集数据上表现得比LTI SSM好得多,尤其是当其状态大小 N \mathrm{N} N增加以允许更多的信息容量时。然而,它只能在递归模式下计算,而不是卷积模式,并且需要仔细考虑硬件感知的实现才能提高效率。尽管如此,它仍然不如对硬件友好的模型(如CNN和Transformer)高效,因为它没有利用矩阵乘法单元,而现代加速器(如GPU和TPU)正是为此而设计的。

虽然时不变SSM与连续、递归和卷积序列模型密切相关,但它们与注意力机制没有直接关系。在本文中,我们展示了选择性SSM与注意力之间的更深层次关系,并利用这种关系显著提高SSM的训练速度,同时允许更大的状态大小 N \mathrm{N} N

结构化SSM作为序列变换

定义2.1. 我们使用术语“序列变换”来指代一个参数化的映射函数 Y = f θ ( X ) Y=f_{\theta}(X) Y=fθ(X),其中 X , Y ∈ R ( T , P ) X, Y \in \mathbb{R}^{(\mathrm{T}, \mathrm{P})} X,YR(T,P),且 θ \theta θ是任意一组参数的集合。 T \mathrm{T} T表示序列或时间轴;下标索引是第一维的索引,例如 X t , Y t ∈ R P X_{t}, Y_{t} \in \mathbb{R}^{\mathrm{P}} Xt,YtRP

序列变换(例如SSM或自注意力机制)是深度序列模型的基础,它们被整合到神经网络架构中(例如Transformer)。在(1)或(2)中的SSM是一个 P = 1 \mathrm{P}=1 P=1的序列变换;它可以通过简单地在这个维度上进行广播(换句话说,将输入视为 P \mathrm{P} P个独立的序列,并对每个序列应用SSM)来推广到 P > 1 \mathrm{P}>1 P>1。可以将 P \mathrm{P} P视为头维度,我们将在第7节中详细阐述。

定义2.2. 我们定义SSM操作符 SSM ⁡ ( A , B , C ) = SSM ⁡ ( A 0 : T , B 0 : T , C 0 : T ) \operatorname{SSM}(A, B, C)=\operatorname{SSM}\left(A_{0: T}, B_{0: T}, C_{0: T}\right) SSM(A,B,C)=SSM(A0:T,B0:T,C0:T)为通过方程(2)定义的序列变换 X ∈ R ( T , P ) ↦ Y ∈ R ( T , P ) X \in \mathbb{R}^{(T, P)} \mapsto Y \in \mathbb{R}^{(T, P)} XR(T,P)YR(T,P)

在SSM中, N \mathrm{N} N维度是一个自由参数,称为状态大小或状态维度。我们也称它为状态扩展因子,因为它将输入/输出的大小扩展了 N \mathrm{N} N倍,这对这些模型的计算效率有影响。

最后,我们注意到许多类型的序列变换,如注意力机制,都可以表示为跨序列维度的单一矩阵乘法。

定义2.3. 如果序列变换 Y = f θ ( X ) Y=f_{\theta}(X) Y=fθ(X)可以写成 Y = M θ X Y=M_{\theta} X Y=MθX的形式,其中 M M M是一个依赖于参数 θ \theta θ的矩阵,我们称该序列变换为矩阵变换。我们将序列变换与矩阵 M M M等同起来,并在上下文清晰时省略对 θ \theta θ的依赖。

2.2、注意力机制

注意力机制泛指一种计算方式,它给序列中每对位置分配分数,允许每个元素“关注”其他元素。到目前为止,最常见且最重要的注意力变体是softmax自注意力机制,它可以定义为

Y = softmax ⁡ ( Q K ⊤ ) ⋅ V Y=\operatorname{softmax}\left(Q K^{\top}\right) \cdot V Y=softmax(QK)V

其中 Q , K , V ∈ R ( T , P ) Q, K, V \in \mathbb{R}^{(\mathrm{T}, \mathrm{P})} Q,K,VR(T,P)。通过实现 Q K ⊤ Q K^{\top} QK进行的成对比较机制导致了注意力机制特有的二次训练成本。

已经提出了许多注意力机制的变体,但它们都共享这些注意力得分的核心原理,以及各种近似方法(Tay et al. 2022)。对于本文工作来说,最重要的变体是线性注意力(Katharopoulos et al. 2020)。粗略地说,这一类方法通过将softmax折叠到核特征映射中来省略它,并利用矩阵乘法的结合律来重写 ( Q K ⊤ ) ⋅ V = Q ⋅ ( K ⊤ V ) \left(Q K^{\top}\right) \cdot V = Q \cdot \left(K^{\top} V\right) (QK)V=Q(KV)。此外,在重要的因果(自回归)注意力情况下,他们展示了当因果掩码被整合到左侧作为 ( L ∘ Q K ⊤ ) ⋅ V \left(L \circ Q K^{\top}\right) \cdot V (LQK)V时,其中 L L L是下三角全1矩阵,那么右侧可以扩展为一个递归。一些最近和并发的工作,如RetNet(Y. Sun et al. 2023)和GateLoop(Katsch 2023),将这一思想推广到了更一般的 L L L形式(第10节)。在这项工作中,我们对结构化掩码注意力的表述将极大地推广这些思想。

2.3、结构化矩阵

一般矩阵 M ∈ R ( T , T ) M \in \mathbb{R}^{(\mathrm{T}, \mathrm{T})} MR(T,T)需要 T 2 \mathrm{T}^{2} T2个参数来表示,并且执行基本操作(如矩阵-向量乘法)需要 O ( T 2 ) O\left(\mathrm{~T}^{2}\right) O( T2)的时间。结构化矩阵是那些
(i) 通过压缩表示可以用次二次(理想情况下是线性)参数来表示的,
(ii) 直接在这个压缩表示上进行操作具有快速算法(最重要的是矩阵乘法)的矩阵。
最典型的结构化矩阵族可能是稀疏矩阵和低秩矩阵。然而,还存在许多其他类型的矩阵族,如Toeplitz矩阵、Cauchy矩阵、Vandermonde矩阵和butterfly矩阵,它们都在机器学习中被用于构建高效的模型(Dao, Gu, 等人 2019;D. Fu 等人 2024;Gu, Gupta, 等人 2022;Thomas 等人 2018)。结构化矩阵是高效表示和算法的有力抽象。在这项工作中,我们将展示SSM(结构化状态空间模型)等价于另一类在深度学习中之前未使用的结构化矩阵,并利用这一联系推导出高效的方法和算法。

2.4、概述:结构化状态空间对偶性

虽然本文在SSM、注意力和结构化矩阵之间建立了更丰富的联系框架,但我们简要概述了主要方法,该方法在算法上其实相当独立且简单。

递归(线性)形式。状态空间对偶(SSD)层可以定义为选择性SSM(2)的一个特例。SSM的标准计算可以作为一个递归(或并行扫描)应用,其复杂度在序列长度上线性增长。与Mamba中使用的版本相比,SSD有两个细微的差异:

  • 矩阵 A A A的结构从对角结构进一步简化为标量乘以恒等结构。在这种情况下,每个 A t A_t At也可以仅用一个标量来标识。
  • 我们使用较大的头维度 P P P,与Mamba中使用的 P = 1 P=1 P=1相比。通常选择 P = { 64 , 128 } P=\{64,128\} P={64,128},这与现代Transformer的惯例相似。

与原始的选择性SSM相比,这些变化可以看作是稍微降低了表达能力,但换来了显著的训练效率提升。特别是,我们的新算法将允许在现代加速器上使用矩阵乘法单元。

对偶(二次)形式。SSD的对偶形式是一个与注意力密切相关的二次计算,定义为

( L ∘ Q K ⊤ ) ⋅ V L i j = { a i × ⋯ × a j + 1 i ≥ j 0 i < j \left(L \circ Q K^{\top}\right) \cdot V \quad L_{i j}=\left\{\begin{array}{ll} a_{i} \times \cdots \times a_{j+1} & i \geq j \\ 0 & i<j \end{array}\right. (LQK)VLij={ai××aj+10iji<j

其中 a i a_i ai是依赖于输入的标量,其值在 [ 0 , 1 ] [0,1] [0,1]之间。

与标准的softmax注意力相比,主要有两个主要差异:

  • 删除了softmax。
  • 注意力矩阵逐元素地与一个额外的掩码矩阵 L L L相乘。

这两个变化都可以看作是解决标准注意力中的问题。例如,最近观察到softmax在注意力分数中造成问题,如“注意力下沉”现象(Darcet et al. 2024; Xiao et al. 2024)。更重要的是,掩码矩阵 L L L可以看作是用一个不同的数据依赖的位置掩码替换Transformer的启发式位置嵌入,这个掩码控制信息在时间上的传递量。

更广泛地说,这种形式是我们在线性注意力中定义的结构化掩码注意力泛化的一个实例,具体定义见第4节。

矩阵形式和SSD算法。通过统一的矩阵表示,SSD的各种形式相互关联,表明SSM具有矩阵变换形式 Y = M X Y=M X Y=MX,其中矩阵 M θ ∈ R ( T , T ) M_{\theta} \in \mathbb{R}^{(T, \mathrm{~T})} MθR(T, T)依赖于 θ = ( A , B , C ) \theta=(A, B, C) θ=(A,B,C)。特别是,SSD的对偶形式等价于矩阵 M M M的朴素(二次时间)乘法,而递归形式是一种利用 M M M中结构的特定高效(线性时间)算法。

除了这些之外,任何用于与 M M M相乘的算法都可以应用。我们提出的硬件高效SSD算法(第6节)是一种新的结构化矩阵乘法方法,它涉及 M M M的块分解,与纯线性或二次形式相比,获得了更好的效率权衡。与一般的选择性SSM(Gu和Dao 2023)相比,它相对简单且易于实现;列表1提供了几行代码的完整实现。

图1提供了本文中呈现的概念之间关系的简单路线图。
在这里插入图片描述

2.5、符号

在本文中,我们倾向于使用可以映射到代码的精确符号。

矩阵和向量。我们通常使用小写字母表示向量(即具有单个轴的张量),使用大写字母表示矩阵(即具有多于一个轴的张量)。在这项工作中,我们不使用粗体表示矩阵。有时,如果矩阵沿着一个轴绑定或重复(因此也可以视为向量),我们可能会使用大写或小写来表示它。 2 {}^{2} 2 - 表示标量或矩阵乘法,而 ∘ \circ 表示Hadamard(逐元素)乘法。

索引。我们使用Python风格的索引,例如i:j指的是当i<j时的范围(i, i+1, ..., j-1),而当i>j时的范围(i, i-1, ..., j+1)。例如,对于任何符号v,我们让v_{j:i}j ≥ i时表示序列(v_{j}, ..., v_{i+1})[i]等价于0:i=(0, ..., i-1)。为了简化,我们还让v_{j:i}^{\times}表示乘积v_{j} \times ... \times v_{i+1}

维度。为了与矩阵和张量区分开来,我们经常使用打字机字体的大写字母(如D, N, T)来表示维度和张量形状。与传统记法M \in \mathbb{R}^{T \times T}不同,我们经常使用M \in \mathbb{R}^{(\mathrm{T}, \mathrm{T})}来反映代码中的张量形状。

张量收缩。我们将大量依赖张量收缩或einsum记法,这既是为了清晰性,也是作为陈述和证明我们结果的核心工具。我们假设读者熟悉这种记法,它在现代张量库(如numpy)中广泛使用。例如,我们可以使用收缩(M N, N K → M K)来表示矩阵-矩阵乘法运算符,而在我们的记法中,收缩(\mathrm{MN}, \mathrm{NK} → \mathrm{MK})(X, Y)(这等价于X \cdot Y)可以翻译为numpy代码numpy.einsum('mn, nk -> mk', X, Y)

附录A中包含了一个大量的符号表。

3、状态空间模型是结构化矩阵

本节探讨了状态空间模型作为序列变换的不同视角,并概述了此类映射的性质和算法。本节的主要结果是状态空间模型与一类称为半可分离矩阵的结构化矩阵之间的等价性,这暗示了新的效率结果(定理3.5和3.7)。

3.1、状态空间模型的矩阵变换形式

回顾一下,我们对SSM的定义是通过(2)定义的参数化映射。我们的理论框架首先通过简单地将这个变换写成一个矩阵乘法映射,将向量 x ∈ R ⊤ x \in \mathbb{R}^{\top} xR映射到 y ∈ R ⊤ y \in \mathbb{R}^{\top} yR

根据定义, h 0 = B 0 x 0 h_{0}=B_{0} x_{0} h0=B0x0。通过归纳,

h t = A t … A 1 B 0 x 0 + A t … A 2 B 1 x 1 + ⋯ + A t A t − 1 B t − 2 x t − 2 + A t B t − 1 x t − 1 + B t x t = ∑ s = 0 t A t : s × B s x s . \begin{aligned} h_{t} & =A_{t} \ldots A_{1} B_{0} x_{0}+A_{t} \ldots A_{2} B_{1} x_{1}+\cdots+A_{t} A_{t-1} B_{t-2} x_{t-2}+A_{t} B_{t-1} x_{t-1}+B_{t} x_{t} \\ & =\sum_{s=0}^{t} A_{t: s}^{\times} B_{s} x_{s} . \end{aligned} ht=AtA1B0x0+AtA2B1x1++AtAt1Bt2xt2+AtBt1xt1+Btxt=s=0tAt:s×Bsxs.

乘以 C t C_{t} Ct来产生 y t y_{t} yt,并在 t ∈ [ T ] t \in[\mathrm{T}] t[T]上对方程进行矢量化,我们得到了SSM的矩阵变换形式。

y t = ∑ s = 0 t C t ⊤ A t : s × B s x s y = SSM ⁡ ( A , B , C ) ( x ) = M x M j i : = C j ⊤ A j ⋯ A i + 1 B i \begin{aligned} y_{t} & =\sum_{s=0}^{t} C_{t}^{\top} A_{t: s}^{\times} B_{s} x_{s} \\ y & =\operatorname{SSM}(A, B, C)(x)=M x \\ M_{j i} & :=C_{j}^{\top} A_{j} \cdots A_{i+1} B_{i} \end{aligned} ytyMji=s=0tCtAt:s×Bsxs=SSM(A,B,C)(x)=Mx:=CjAjAi+1Bi

3.2、半可分离矩阵

在等式(3)中的 M M M是称为半可分离矩阵的一类矩阵的特定表示。半可分离矩阵是一种基本的矩阵结构。我们首先定义这些矩阵及其性质。

定义3.1.(下三角)矩阵 M M M N \mathrm{N} N-半可分离的,如果其下三角部分(即对角线上或下方的部分)包含的每个子矩阵的秩最多为 N \mathrm{N} N。我们称 N \mathrm{N} N为半可分离矩阵的阶数或秩。

定义3.1,以及其他形式的相关“可分离”结构(例如准可分离矩阵和半可分离矩阵的其他定义)有时被称为结构化秩矩阵(或秩结构化矩阵),因为它们由其子矩阵的秩条件来表征。半可分离矩阵有许多结构化表示形式,包括分层半可分离(HSS)、顺序半可分离(SSS)和Bruhat形式(Pernet和Storjohann 2018)。我们主要使用SSS形式。

3.2.1、顺序半可分离(SSS)表示

定义3.2. 一个下三角矩阵 M ∈ R ( T , T ) M \in \mathbb{R}^{(\mathrm{T}, \mathrm{T})} MR(T,T)有一个 N \mathrm{N} N-顺序半可分离(SSS)表示,如果它可以写成以下形式

M j i = C j ⊤ A j ⋯ A i + 1 B i M_{j i}=C_{j}^{\top} A_{j} \cdots A_{i+1} B_{i} Mji=CjAjAi+1Bi

其中,向量 B 0 , … , B T − 1 , C 0 , … , C T − 1 ∈ R N B_{0}, \ldots, B_{\mathrm{T}-1}, C_{0}, \ldots, C_{\mathrm{T}-1} \in \mathbb{R}^{\mathrm{N}} B0,,BT1,C0,,CT1RN和矩阵 A 0 , … , A ⊤ − 1 ∈ R ( N , N ) A_{0}, \ldots, A_{\top-1} \in \mathbb{R}^{(\mathbb{N}, \mathrm{N})} A0,,A1R(N,N)
我们定义操作符 SSS ⁡ \operatorname{SSS} SSS,使得 M = SSS ⁡ ( A 0 : T , B 0 : T , C 0 : T ) M=\operatorname{SSS}\left(A_{0: \mathrm{T}}, B_{0: \mathrm{T}}, C_{0: \mathrm{T}}\right) M=SSS(A0:T,B0:T,C0:T)

半可分离矩阵的一个基本结果是它们与具有SSS表示的矩阵完全等价。一个方向可以通过一个简单的构造性证明来推导。

引理3.3. 一个 N \mathrm{N} N-SSS矩阵 M M M,其表示形式为(4),是 N \mathrm{N} N-半可分离的。

证明。考虑任何非对角块 M j : j ′ , i ′ : i M_{j: j^{\prime}, i^{\prime}: i} Mj:j,i:i,其中 j ′ > j ≥ i > i ′ j^{\prime}>j \geq i>i^{\prime} j>ji>i。这个块有一个明确的秩为 N N N的分解,表示为

(方程(5)将在推导我们序列模型的快速算法时广泛使用。另一个方向在关于半可分离矩阵的文献中已经得到很好的建立。)

命题3.4. 每一个 N \mathrm{N} N-半可分离矩阵都有一个 N \mathrm{N} N-SSS表示。

此外,请注意,尽管定义3.2中的表示涉及 O ( N 2 T ) O\left(\mathrm{~N}^{2} \mathrm{~T}\right) O( N2 T)个参数(特别是用于存储 A A A矩阵),但实际上它可以压缩到 O ( N T ) O(\mathrm{NT}) O(NT)个参数,这在渐近意义上是紧的(Pernet, Signargout, 和 Villard 2023)。因此,在本文的其余部分,我们将把结构化矩阵类(定义3.1)和它的一个特定表示(定义3.2)混为一谈;我们将始终使用这种表示而不是其他候选者。反过来,我们将使用 N \mathrm{N} N-SS来指代SSS形式的 N \mathrm{N} N-半可分离矩阵。

半可分离矩阵是一种基本的矩阵结构,具有许多重要的性质。它们与广泛的递推密切相关,并可以通过多种表征(例如定义3.1和3.2)来定义,这些表征揭示了它们之间的不同联系和高效算法。我们在附录C.1中提到了一些它们的其他性质。

备注2. 半可分性的概念非常广泛,文献中出现了许多相似但微妙不同的定义;我们的定义可能与其他约定略有不同。首先,由于本文主要关注因果或自回归设置,我们将半可分性的定义限制为三角情况;定义3.1在一些作者那里可能更正式地被称为( N , 0 \mathrm{N}, 0 N,0)-半可分性。一些作者也可能将其称为准可分性的一种形式(Eidelman 和 Gohberg 1999;Pernet 2016)。关于简短概述,请参阅Vandebril等人(2005)。

3.2.2、1-半可分矩阵:标量SSM递推

我们将单独讨论1-SS矩阵的特殊情况。注意,在这种情况下, C j C_{j} Cj B i B_{i} Bi是标量,并且可以从SSS表示(4)中分离出来(我们也使用小写字母来强调在这种情况下参数是标量)

SSS ⁡ ( a , b , c ) = diag ⁡ ( c ) ⋅ M ⋅ diag ⁡ ( b ) 其中 M j i = a j : i × ⋅ \operatorname{SSS}(a, b, c)=\operatorname{diag}(c) \cdot M \cdot \operatorname{diag}(b) \quad \text{其中} \quad M_{j i}=a_{j: i}^{\times} \cdot SSS(a,b,c)=diag(c)Mdiag(b)其中Mji=aj:i×

由于对角矩阵易于处理(例如,对角矩阵的乘法等同于逐元素的标量乘法),我们可以忽略这些项。因此,1-SS矩阵的基本表示形式为 M j i = a j : i M_{j i}=a_{j: i} Mji=aj:i

M = SSS ⁡ ( a 0 : T ) : = [ 1 a 1 1 a 2 a 1 a 2 1 ⋮ ⋮ ⋱ ⋱ a T − 1 … a 1 a T − 1 … a 2 … a T − 1 1 ] M=\operatorname{SSS}\left(a_{0: T}\right):=\left[\begin{array}{ccccc} 1 & & & & \\ a_{1} & 1 & & & \\ a_{2} a_{1} & a_{2} & 1 & & \\ \vdots & \vdots & \ddots & \ddots & \\ a_{T-1} \ldots a_{1} & a_{T-1} \ldots a_{2} & \ldots & a_{T-1} & 1 \end{array}\right] M=SSS(a0:T):= 1a1a2a1aT1a11a2aT1a21aT11

1-SS矩阵的重要性在于它们等价于标量递归的最小形式——状态维度 N = 1 \mathrm{N}=1 N=1且没有 ( B , C ) (B, C) (B,C)投影的SSM退化情况。注意,乘法 y = M x y=M x y=Mx可以通过递归来计算:

y t = a t : 0 x 0 + ⋯ + a t : t x t = a t ( a t − 1 : 0 x 0 + ⋯ + a t − 1 : t − 1 x t − 1 ) + a t : t x t = a t y t − 1 + x t \begin{aligned} y_{t} & =a_{t: 0} x_{0}+\cdots+a_{t: t} x_{t} \\ & =a_{t}\left(a_{t-1: 0} x_{0}+\cdots+a_{t-1: t-1} x_{t-1}\right)+a_{t: t} x_{t} \\ & =a_{t} y_{t-1}+x_{t} \end{aligned} yt=at:0x0++at:txt=at(at1:0x0++at1:t1xt1)+at:txt=atyt1+xt

因此,我们也称1-SS矩阵的矩阵乘法为标量SSM递归或cumprodsum(累积乘积和;累积乘积和累积和的泛化)算子。作为递归的基本形式,1-SS矩阵的乘法对于我们的主要算法来说是一个重要的构建块。

我们强调,本文的一个中心主题是许多序列模型上的算法都可以简化为结构化矩阵乘法算法。1-SS矩阵就是这一联系的例证:有许多快速算法用于计算基本的标量递归或cumprodsum算子,所有这些算法都等价于1-SS矩阵的不同结构化分解。我们在附录B中专门讨论了1-SS矩阵乘法的这些算法。

3.3、状态空间模型是半可分矩阵

回顾一下,我们定义的状态空间模型(SSM)是通过定义2.1定义的参数化映射。SSM与半可分矩阵之间的联系来自于简单地将这种变换表示为矩阵乘法,映射向量 x x x y ∈ R ⊤ y \in \mathbb{R}^{\top} yR。方程(3)直接建立了状态空间模型与顺序半可分表示之间的联系,而顺序半可分表示又等价于一般的半可分矩阵(引理3.3和命题3.4)。

定理3.5. 状态空间模型变换 y = SSM ⁡ ( A , B , C ) ( x ) y=\operatorname{SSM}(A, B, C)(x) y=SSM(A,B,C)(x) 在状态大小为 N \mathrm{N} N 的情况下,等同于顺序半可分表示下的 N \mathrm{N} N -SS 矩阵乘法 y = SSS ⁡ ( A , B , C ) ⋅ x y=\operatorname{SSS}(A, B, C) \cdot x y=SSS(A,B,C)x

换句话说,序列变换算子SSM(定义2.2)与矩阵构造算子SSS(定义3.2)是相吻合的,我们交替使用它们(或有时用SS作为简写)。此外,命运的巧合使得结构化状态空间模型和顺序半可分矩阵具有相同的缩写,这强调了它们的等价性!为了方便起见,我们可以交替使用这些缩写SSM(状态空间模型或半可分矩阵)、SSS(结构化状态空间或顺序半可分)或SS(状态空间或半可分),以明确指代任一概念。然而,我们通常会按照惯例,SSM指状态空间模型,SS指半可分,SSS指顺序半可分。

图2展示了状态空间模型作为半可分矩阵的序列变换视角。
在这里插入图片描述

3.4、通过结构化矩阵算法计算状态空间模型

定理3.5之所以重要,是因为它将允许我们将SSM(以及其他序列模型)的高效计算问题简化为结构化矩阵乘法的高效算法。在展示SSM与其他序列模型的等价性(第4节和第5节)之后,我们简要概述一下,并将我们的主要新算法推迟到第6节。

如前所述,半可分矩阵(即秩结构矩阵)是一种经典的结构化矩阵类型:
(i) 它们具有压缩表示形式,如SSS形式,该形式仅有 O ( T ) O(\mathrm{~T}) O( T)而不是 O ( T 2 ) O\left(\mathrm{~T}^{2}\right) O( T2)个参数。
(ii) 它们有直接在压缩表示上操作的快速算法。

此外,参数化和矩阵乘法成本可以在半可分顺序中达到紧凑。

命题3.6 (Pernet, Signargout, 和 Villard (2023)). 一个大小为 T \mathrm{T} T N \mathrm{N} N-SS矩阵可以用 O ( N T ) O(\mathrm{NT}) O(NT)个参数表示,并且在时间和空间上都具有 O ( N T ) O(\mathrm{NT}) O(NT)的矩阵-向量乘法成本。

例如,1-SS矩阵说明了这种联系的本质。矩阵 M = 1 S S ( a ) M=1\mathrm{SS}(a) M=1SS(a)正是由恰好 T − 1 \mathrm{T}-1 T1个参数 a 0 : T − 1 = a 1 , … , a T − 1 a_{0:\mathrm{T}-1}=a_{1}, \ldots, a_{\mathrm{T}-1} a0:T1=a1,,aT1定义的,并且可以通过遵循标量递归(7)在 O ( T ) O(\mathrm{~T}) O( T)时间内计算。

3.4.1、线性(递归)模式

命题3.6在具有对角结构的状态空间模型(S4D,Gu, Gupta等人,2022)的情况下可以很容易地看到,只需利用状态空间模型公式(2)并展开递归即可。我们在(8)中提供了正式的张量收缩算法,其中维度 S \mathrm{S} S等于 T 4 \mathrm{T}^{4} T4

Z = contract  ( S P , S N → S P N ) ( X , B ) ( S , P , N ) H = contract  ( T S N , S P N → T P N ) ( L , Z ) ( T , P , N ) Y = contract  ( T N , T P N → T P ) ( C , H ) ( T , P ) \begin{array}{rlr} Z & =\text{contract }(\mathrm{SP}, \mathrm{SN} \rightarrow \mathrm{SPN})(X, B) & (\mathrm{S}, \mathrm{P}, \mathrm{N}) \\ H & =\text{contract }(\mathrm{TSN}, \mathrm{SPN} \rightarrow \mathrm{TPN})(L, Z) & (\mathrm{T}, \mathrm{P}, \mathrm{N}) \\ Y & =\text{contract }(\mathrm{TN}, \mathrm{TPN} \rightarrow \mathrm{TP})(C, H) & (\mathrm{T}, \mathrm{P}) \end{array} ZHY=contract (SP,SNSPN)(X,B)=contract (TSN,SPNTPN)(L,Z)=contract (TN,TPNTP)(C,H)(S,P,N)(T,P,N)(T,P)

在这里, L ∈ R ( T , T ) L \in \mathbb{R}^{(\mathrm{T}, \mathrm{T})} LR(T,T)被定义为 1 S S ( A ) 1\mathrm{SS}(A) 1SS(A),换句话说,对于 i ∈ [ N ] i \in[\mathrm{N}] i[N],有 L 0 : T , 0 : T = 1 S S ( A 0 : T ) L_{0:\mathrm{T}, 0:\mathrm{T}}=1\mathrm{SS}\left(A_{0:\mathrm{T}}\right) L0:T,0:T=1SS(A0:T)。这个算法包含三个步骤,对应于(2):

(i) 通过输入矩阵 B B B扩展输入 X X X(8a),
(ii) 展开独立的标量SSM递归(8b),
(iii) 通过输出矩阵 C C C收缩隐藏状态 H H H(8c)。

请注意,我们在步骤(8b)中使用了标量SSM和1-SS矩阵之间的等价性。

备注3. 我们注意到(8)是Mamba(S6)模型的一个特例。然而,由于扩展的张量 Z Z Z H H H的大小为(T,P,N),因此朴素的实现会很慢;Gu和Dao(2023)引入了一种硬件感知的实现方式,以避免实际创建这些张量。

令人惊讶的是,定理3.5和命题3.6立即表明所有SSM(状态空间模型)都具有与算法(8)相同的渐近效率。

定理3.7. 任何状态大小为 N \mathrm{N} N、序列长度为 T \mathrm{T} T的状态空间模型(定义2.2)都可以在时间 O ( T N ) O(\mathrm{TN}) O(TN)内计算(不考虑潜在的预处理)。

我们注意到,这个结果在结构化SSM文献中是新的。特别是,给定密集的未结构化 A t A_{t} At矩阵,仅总表示本身的大小就似乎是 O ( T N 2 ) O\left(\mathrm{TN}^{2}\right) O(TN2)。因此,定理3.7给出了一个非平凡的结果,即通过预处理步骤,即使未结构化的SSM也可以以最优效率计算,其上限与由 B B B C C C的大小给出的下限 O ( T N ) O(\mathrm{TN}) O(TN)相匹配。

备注4. 鉴于几乎所有在 R ( N , N ) \mathbb{R}^{(\mathbb{N}, \mathbb{N})} R(N,N)上的密集矩阵都可以在 C \mathbb{C} C上对角化这一事实,定理3.7可能并不令人太惊讶,这导致了几乎所有密集的实值SSM(状态空间模型)都等价于一个对角复值SSM的结果。这一事实是为什么对角SSM是结构化SSM最流行形式的原因(Gu, Gupta等人,2022;Gupta, Gu, 和 Berant,2022;Ұ. T. Smith, Warrington, 和 Linderman,2023)。然而,定理3.7为所有实值SSM(不仅仅是可对角化的SSM),以及在其他域(包括 C \mathbb{C} C本身)上的密集SSM提供了更为强大的结果。

在实践中,可高效计算的SSM仍然需要 A A A上的额外结构,特别是为了避免昂贵的预处理步骤(该步骤具有 N N N阶额外的浮点运算,并且涉及硬件效率低下的操作,如奇异值分解)。这些结构是过去关于结构化SSM(如S4(D)和Mamba)的工作以及我们新算法的重点。特别是,当对 A A A施加稍微更强的结构时,我们将在第6节中通过SSM矩阵 M = SSS ⁡ ( A , B , C ) M=\operatorname{SSS}(A, B, C) M=SSS(A,B,C)的块分解来设计非常硬件高效的算法。

3.4.2、二次(朴素)模式

我们注意到,我们的新矩阵视角揭示了计算SSM的另一种方法。朴素地计算SSM的矩阵表示(3)涉及简单地实现序列变换矩阵 M = SSS ⁡ ( A , B , C ) M=\operatorname{SSS}(A, B, C) M=SSS(A,B,C)。这是一个 ( T , T ) (\mathrm{T}, \mathrm{T}) (T,T)矩阵,因此这种朴素的算法在计算规模上将随着序列长度的二次方增长。然而,当序列长度 T T T较短时,由于常数因子和计算模式的硬件友好性(例如,利用矩阵-矩阵乘法),这实际上可能比线性算法更高效。事实上,对于结构化SSM的特定情况,这看起来与二次注意力计算非常相似(见第5节)。

3.4.3、总结

许多序列模型明确地以矩阵序列变换为动机或定义,最显著的是Transformer,其中的矩阵混合器是注意力矩阵。另一方面,RNN和SSM之前并没有以这种方式描述。通过提供状态空间模型的显式矩阵变换形式,我们揭示了理解和使用它们的新方法。从计算的角度来看,任何计算状态空间模型前向传播的方法都可以视为半可分离矩阵上的矩阵乘法算法。半可分离矩阵的视角为状态空间对偶性(SSD)提供了一个视角,其中对偶模式分别指线性时间半可分离矩阵乘法算法和二次时间朴素矩阵乘法。

此外,利用半可分离矩阵的丰富结构可以导出更好的算法和更多的见解(例如第6节和附录B)。在附录C.1中,我们描述了半可分离矩阵的一些附加属性。
4 结构化掩码注意力:通过结构化矩阵泛化线性注意力

在本节中,我们将从头开始重新审视线性注意力框架。本节的主要成果是一个基于张量收缩的简单证明来阐述线性注意力(命题4.1),以及我们在定义4.2中对结构化掩码注意力的泛化抽象。我们注意到,这一节从与状态空间模型不同的方向推导出了主要对偶性结果,并且可以与第3节完全独立地阅读。

  • 第4.1节为我们的注意力变体设置了框架,特别关注核注意力和掩码核注意力。
  • 第4.2节提供了我们的第一个主要注意力结果,即通过张量收缩的视角简单证明线性注意力。
  • 第4.3节定义了结构化掩码注意力,这是我们通过结构化矩阵对先前注意力变体的泛化。

4.1、注意力框架

4.1.1、注意力

(单头)注意力的基本形式是一个将三个向量序列 ( Q , K , V ) (Q, K, V) (Q,K,V)映射到 Y Y Y的函数。

Q = input ( T , N ) K = input ( S , N ) V = input ( S , P ) G = Q K ⊤ ( T , S ) M = f ( G ) ( T , S ) Y = G V ( T , P ) (9) \begin{aligned} Q &= \text{input} & (\mathrm{T}, \mathrm{N}) \\ K &= \text{input} & (\mathrm{S}, \mathrm{N}) \\ V &= \text{input} & (\mathrm{S}, \mathrm{P}) \\ G &= Q K^{\top} & (\mathrm{T}, \mathrm{S}) \\ M &= f(G) & (\mathrm{T}, \mathrm{S}) \\ Y &= G V & (\mathrm{T}, \mathrm{P}) \end{aligned} \tag{9} QKVGMY=input=input=input=QK=f(G)=GV(T,N)(S,N)(S,P)(T,S)(T,S)(T,P)(9)

我们使用“形状注解”来表示张量的维度,例如 Q ∈ R ( T , N ) Q \in \mathbb{R}^{(\mathrm{T}, \mathrm{N})} QR(T,N)。在这个一般形式中, S \mathrm{S} S T \mathrm{T} T分别代表源序列和目标序列的长度, N \mathrm{N} N代表特征维度,而 P \mathrm{P} P代表头维度。

最常见的softmax注意力变体使用softmax激活函数 f = softmax f = \text{softmax} f=softmax来标准化 G G G矩阵的行。

4.1.2、自注意力

我们的讨论基于自注意力的重要情况,其中

(i) 源序列和目标序列是相同的(即 S = T \mathrm{S}=\mathrm{T} S=T),
(ii) 通常特征维度和头维度是相同的(即 N = P \mathrm{N}=\mathrm{P} N=P),
(iii) Q , K , V Q, K, V Q,K,V 是通过对相同的输入向量 X X X 进行线性投影生成的(即 Q = W Q ⋅ X , K = W K ⋅ X , V = W V ⋅ X Q=W_{Q} \cdot X, K=W_{K} \cdot X, V=W_{V} \cdot X Q=WQX,K=WKX,V=WVX)。

然而,我们的讨论抽象了这些选择,并从 Q , K , V Q, K, V Q,K,V 矩阵开始。

备注5:我们的重点是自注意力情况,其中头维度和特征维度相等(即 S = T \mathrm{S}=\mathrm{T} S=T N = P \mathrm{N}=\mathrm{P} N=P),这应该作为运行示例使用。我们定义注意力的一般形式不仅是因为我们的框架捕获了诸如交叉注意力这样的变体,还因为分离维度符号(例如 S 和 T \mathrm{T} T)使得本节中我们主要结果的收缩符号证明更加清晰。

备注6:虽然注意力通常被描述为对这三个输入 Q , K , V Q, K, V Q,K,V 进行的操作,并且这三个输入被视为对称的,但 (9) 中的输入和输出维度表明了并非如此。特别是,输出中不存在特征维度 N \mathrm{N} N;因此,在 S = T \mathrm{S}=\mathrm{T} S=T(例如自注意力)的情况下,我们视 V V V 为主要输入,这样 (9) 定义了一个适当的序列转换 V ↦ Y V \mapsto Y VY(定义 2.1)。

4.1.3、核注意力

将softmax函数应用于Gram矩阵 G G G的步骤可以分解为两个部分:

  1. G G G矩阵进行指数运算。
  2. S S S轴上对 G G G矩阵进行归一化。

我们现在可以忽略归一化项,因为它仅仅相当于将 V = 1 V=1 V=1传入并除以相应的值(我们将在第7.3节中重新讨论这一点)。指数项可以被视为一个核变换:存在一个(无限维)特征映射 φ \varphi φ,使得 exp ⁡ ( Q K ⊤ ) = φ ( Q ) φ ( K ) ⊤ \exp \left(Q K^{\top}\right)=\varphi(Q) \varphi(K)^{\top} exp(QK)=φ(Q)φ(K)。通过将特征映射抽象到 Q Q Q K K K的定义中(即定义 Q , K Q, K Q,K为后变换的版本),我们可以忽略softmax变换,并假设 Q , K Q, K Q,K是由核特征映射任意生成的,并且可能 N ≠ P \mathrm{N} \neq \mathrm{P} N=P

已经提出了许多核注意力的实例化,包括:

  • 原始线性注意力(Linear Attention,Katharopoulos等人2020)将核特征映射定义为任意逐点激活函数,例如 x ↦ 1 + elu ⁡ ( x ) x \mapsto 1+\operatorname{elu}(x) x1+elu(x)
  • 随机特征注意力(Random Feature Attention,RFA)(H. Peng等人2021)选择核特征映射来近似softmax注意力(即exp特征映射),它使用了高斯核的随机傅里叶特征近似(Rahimi和Recht 2007)。这涉及随机投影(即 Q Q Q K K K乘以一个随机投影矩阵 W W W并应用激活函数 x ↦ ( cos ⁡ ( x ) , sin ⁡ ( x ) ) x \mapsto (\cos(x), \sin(x)) x(cos(x),sin(x)))。
  • Performer(Choromanski等人2021)提出了通过正交随机特征(FAVOR +)的快速注意力。其中的正随机特征(PRF)部分选择了核特征映射为随机投影后跟随特征映射 x ↦ 2 − 1 / 2 ( exp ⁡ ( x ) , exp ⁡ ( − x ) ) x \mapsto 2^{-1/2}(\exp(x), \exp(-x)) x21/2(exp(x),exp(x))。这种选择是为了确保核元素是正值,并且可证明地近似softmax注意力。[它还提出了在正交方向上选择随机投影,但我们不考虑这一点。]
  • cosFormer(Qin, Weixuan Sun等人2022)在RFA的基础上增加了余弦重加权机制,该机制结合了位置信息来强调局部性。这实际上是通过特征映射 x ↦ ( x cos ⁡ ( π t / 2 T ) , sin ⁡ ( π t / 2 T ) ) x \mapsto (x \cos(\pi t / 2T), \sin(\pi t / 2T)) x(xcos(πt/2T),sin(πt/2T))来传递 Q t , K t Q_t, K_t Qt,Kt
  • 线性随机注意力(Linear Randomized Attention,Zheng, C. Wang, 和 Kong 2022)从重要性采样的角度泛化了RFA,并将其推广以提供对完整softmax核的更好估计(而不仅仅是exp变换的分子部分)。

4.1.4、掩码(核)注意力

L L L是一个形状为( T , S \mathrm{T}, \mathrm{S} T,S)的掩码。最常见的情况是,在自回归自注意力情况下,当 S = T \mathrm{S}=\mathrm{T} S=T时, L L L可能是一个由1组成的下三角矩阵,表示因果掩码。除了强制因果性外,还可以应用许多其他类型的掩码——特别是各种稀疏模式,如带状、扩张或块对角,它们的目的是降低密集注意力的复杂性。

掩码注意力通常用矩阵符号表示为

KaTeX parse error: \tag works only in display equations

更具体地说,通过形状注释和将其分解为精确的计算序列:

G = Q K ⊤ ( T , S ) M = G ∘ L ( T , S ) Y = M V ( T , P ) (11) \begin{array}{rlrl} G & = QK^{\top} & & (\mathrm{T}, \mathrm{S}) \\ M & = G \circ L & (\mathrm{~T}, \mathrm{~S}) \\ Y & = MV & (\mathrm{~T}, \mathrm{P}) \end{array} \tag{11} GMY=QK=GL=MV( T, S)( T,P)(T,S)(11)

本节中我们对注意力变种的改进推导始于注意到这个公式可以写成一个单一的收缩操作:

Y = contract(TN, SN, SP, TS → TP) ( Q , K , V , L ) (12) Y = \text{contract(TN, SN, SP, TS} \rightarrow \text{TP)}(Q, K, V, L)\tag{12} Y=contract(TN, SN, SP, TSTP)(Q,K,V,L)(12)

并且(11)中的算法可以通过特定顺序的成对收缩操作重新表述为计算(12)

G = contract ⁡ ( T N , S N → T S ) ( Q , K ) ( T , S ) M = contract ⁡ ( T S , T S → T S ) ( G , L ) ( T , S ) Y = contract ⁡ ( T S , S P → T P ) ( M , V ) ( T , P ) (13) \begin{array}{rlr} G & = \operatorname{contract}(\mathrm{TN}, \mathrm{SN} \rightarrow \mathrm{TS})(Q, K) & (\mathrm{T}, \mathrm{S}) \\ M & = \operatorname{contract}(\mathrm{TS}, \mathrm{TS} \rightarrow \mathrm{TS})(G, L) & (\mathrm{T}, \mathrm{S}) \\ Y & = \operatorname{contract}(\mathrm{TS}, \mathrm{SP} \rightarrow \mathrm{TP})(M, V) & (\mathrm{T}, \mathrm{P}) \end{array} \tag{13} GMY=contract(TN,SNTS)(Q,K)=contract(TS,TSTS)(G,L)=contract(TS,SPTP)(M,V)(T,S)(T,S)(T,P)(13)

4.2、线性注意力

线性注意力以及许多其他高效的注意力变体通常是通过改变核心注意力计算中矩阵结合律的顺序来实现的,即 ( Q K ⊤ ) V = Q ( K ⊤ V ) (QK^{\top})V = Q(K^{\top}V) (QK)V=Q(KV)。但是,当添加掩码时,推导过程就不那么直接了(例如,原始论文(Katharopoulos et al. 2020)及其变体(Y. Sun et al. 2023)给出了公式但没有证明)。

大致上,线性注意力方法声称以下公式与(10)等价,这需要通过仔细展开求和并跟踪索引来验证。

Y = Q ⋅ cumsum ( K ⊤ V ) (14) Y = Q \cdot \text{cumsum}(K^{\top}V) \tag{14} Y=Qcumsum(KV)(14)

命题 4.1(Katharopoulos et al. 2020)。自回归核注意力,即带有因果掩码的掩码核注意力,可以通过每次步骤花费常量时间的递归来在 O ( T ) O(T) O(T) 时间内计算。

4.2.1、线性注意力的张量收缩证明

我们给出了线性注意力的一个简单且严谨的推导,这也将立即揭示如何推广它。主要思想是以另一种顺序执行收缩(12)。我们避免使用模棱两可的矩阵符号,而是直接使用收缩符号:

Z = contract(SP, SN → SPN ) ( V , K ) ( S , P , N ) H = contract ( T S , S P N → TPN ) ( L , Z ) ( T , P , N ) Y = contract ( T N , T P N → TP ) ( Q , H ) ( T , P ) (15) \begin{array}{rlr} Z & = \text{contract(SP, SN} \rightarrow \text{SPN})(V, K) & (\mathrm{S}, \mathrm{P}, \mathrm{N}) \\ H & = \text{contract}(\mathrm{TS}, \mathrm{SPN} \rightarrow \text{TPN})(L, Z) & (\mathrm{T}, \mathrm{P}, \mathrm{N}) \\ Y & = \text{contract}(\mathrm{TN}, \mathrm{TPN} \rightarrow \text{TP})(Q, H) & (\mathrm{T}, \mathrm{P}) \end{array} \tag{15} ZHY=contract(SP, SNSPN)(V,K)=contract(TS,SPNTPN)(L,Z)=contract(TN,TPNTP)(Q,H)(S,P,N)(T,P,N)(T,P)(15)

直观上,我们这样解释这种收缩顺序。
第一步(15a)通过特征维度 N \mathrm{N} N 的一个因子来执行“扩展”到更多的特征。第三步(15c)将扩展的特征维度收缩回去。如果将 K K K 视为输入(备注 6),那么 V V V Q Q Q 分别执行扩展和收缩。

第二步是最关键的,它解释了线性注意力中的“线性”部分。首先注意到(15b)只是一个由 L L L进行的直接矩阵乘法(因为( P , N \mathrm{P}, \mathrm{N} P,N)轴可以被展平)。同时注意到这是唯一同时涉及 T \mathrm{T} T S \mathrm{S} S轴的项,因此应该有 Ω ( T S ) \Omega(TS) Ω(TS)的复杂度(即序列长度的二次方)。然而,当掩码 L L L是标准的因果注意力掩码(下三角全为1)时, L L L的矩阵-向量乘法等同于特征维度的累积和。

y = [ 1 ⋮ ⋱ 1 … 1 ] x ⟺ y 0 = x 0 y t = y t − 1 + x t y=\left[\begin{array}{ccc} 1 & & \\ \vdots & \ddots & \\ 1 & \ldots & 1 \end{array}\right] x \Longleftrightarrow \begin{array}{l} y_{0}=x_{0} \\ y_{t}=y_{t-1}+x_{t} \end{array} y= 111 xy0=x0yt=yt1+xt

4.3、结构化掩码注意力

从掩码注意力的张量收缩视角(15)来看,我们可以立即看到原始线性注意力的关键在于,通过因果掩码的矩阵-向量乘法等同于累积和运算符。

然而,我们观察到,注意力掩码并不一定要全部为1。线性注意力快速运行的必要条件是 L L L是一个结构化矩阵,按定义,结构化矩阵是那些具有快速矩阵乘法的矩阵(第2.3节)。特别是,我们可以使用任何具有次二次(理想情况下是线性)矩阵-向量乘法的掩码矩阵 L L L,这会通过加速瓶颈方程(15b)来使复杂度与标准线性注意力相同。

定义 4.2. 结构化掩码注意力(SMA)(或简称结构化注意力)定义为对查询/键/值 Q , K , V Q, K, V Q,K,V以及任何结构化矩阵 L L L(即具有次二次矩阵乘法)的函数,通过4路张量收缩

Y = contract ( T N , S N , S P , T S → T P ) ( Q , K , V , L ) Y=\text{contract}(\mathrm{TN}, \mathrm{SN}, \mathrm{SP}, \mathrm{TS} \rightarrow \mathrm{TP})(Q, K, V, L) Y=contract(TN,SN,SP,TSTP)(Q,K,V,L)

SMA二次模式算法是由(13)定义的成对收缩序列,它对应于标准的(掩码)注意力计算。

SMA线性模式算法是由(15)定义的成对收缩序列,其中步骤(15b)通过次二次结构化矩阵乘法进行了优化。
在这里插入图片描述

我们可以将结构化掩码注意力实例化到任何给定的矩阵结构类别。一些例子包括(图3):

  • 线性注意力使用因果掩码。
  • RetNet(Y. Sun等人,2023年)使用衰减掩码 L i j = γ i − j ⋅ I [ j ≥ i ] L_{ij} = \gamma^{i-j} \cdot \mathbb{I}[j \geq i] Lij=γijI[ji],其中 γ ∈ [ 0 , 1 ] \gamma \in [0,1] γ[0,1]是某个衰减因子。
  • 衰减掩码可以推广为Toeplitz矩阵 L i j = α i − j L_{ij} = \alpha_{i-j} Lij=αij,其中 α ∈ R ⊤ \alpha \in \mathbb{R}^{\top} αR是一组可学习的(或依赖于输入的)参数。这可以解释为一种相对位置编码的形式,类似于AliBi(Press, N. Smith, 和 Lewis 2022年)等方法,但这里是乘法而非加法。
  • 另一种变体可以使用傅里叶矩阵 L i j = ω i j / T L_{ij} = \omega^{ij/T} Lij=ωij/T来以不同的方式编码位置结构。

在第5节中,我们考虑了半可分SMA,它定义了我们主要的SSD模型。
4.3.1 总结:掩码注意力的双重形式

标准(掩码核)注意力经常混淆于函数和算法之间。区分这种差异为理解注意力的不同变体提供了一种清晰的方式。

  • 我们将掩码注意力视为一个特定的函数(12)。
  • 标准二次注意力计算(13)可以被视为计算这个函数的算法。
  • 线性注意力(15)是计算相同函数的另一种算法。

此外,在这种情况下

  • 掩码注意力函数仅仅是四个项上的特定收缩。
  • 二次注意力和线性注意力算法仅仅是执行收缩的两种不同顺序。

已知收缩顺序可以对计算复杂度产生重大影响,导致了二次与线性的分裂。正如状态空间模型是可以通过多种方式计算的转换,具有双重的二次与线性形式(第3.4节),线性注意力也有类似的双重性,这种双重性来自于两种不同的收缩顺序。事实上,这些最终是对相同基础双重性的不同视角,我们将在第5节中明确这一点。

5、状态空间对偶性

在第3节和第4节中,我们定义了结构化状态空间模型和结构化注意力,讨论了它们的性质,并展示了它们都有二次算法和线性算法。本节将它们联系起来。我们的主要结果是展示结构化状态空间模型的特定情况与结构化注意力的特定情况相吻合,并且线性时间SSM算法和二次时间核注意力算法是彼此的对偶形式。

  • 第5.1节将状态空间模型专门化为标量结构,其中朴素的二次计算可以看作是核注意力的一个实例。
  • 第5.2节将结构化掩码注意力专门化为半可分SMA,它以有效自回归的方式描述了掩码注意力。
  • 第5.3节总结了结构化掩码注意力和结构化状态空间模型之间的联系,称为结构化状态空间对偶性。

5.1、标量-恒等结构化状态空间模型

在第3节中,我们展示了状态空间模型等价于半可分矩阵变换,从而得到了线性递归形式和二次朴素形式。
回想一下,SSM(状态空间模型)被定义为 y = SSM ( A , B , C ) ( x ) y = \text{SSM}(A, B, C)(x) y=SSM(A,B,C)(x),并且SSM的矩阵形式使用了SSS(连续半可分)表示 M = SSS ( A , B , C ) M = \text{SSS}(A, B, C) M=SSS(A,B,C),其中 M j i = C j ⊤ A j : i B i M_{ji} = C_j^\top A_{j:i} B_i Mji=CjAj:iBi(方程(3))。

现在让我们考虑一个特殊情况,即 A j A_j Aj 只是一个标量;换句话说,这是一个结构化SSM的实例化,其中A矩阵是极其结构化的: A = a I A = aI A=aI 对于标量 a a a 和单位矩阵 I I I。那么我们可以重新排列

M j i = A j : i ⋅ ( C j ⊤ B i ) M_{ji} = A_{j:i} \cdot (C_j^\top B_i) Mji=Aj:i(CjBi)

这可以矢量化为

L : = 1 SS ( a ) M = L ∘ ( C B ⊤ ) \begin{aligned} L & := 1 \text{SS}(a) \\ M & = L \circ (CB^\top) \end{aligned} LM:=1SS(a)=L(CB)

其中 B , C ∈ R ( T , N ) B, C \in \mathbb{R}^{(T, N)} B,CR(T,N)

使用这种表述,完整的输出 Y = M X Y = MX Y=MX 精确计算为

G = contract(TN, SN → TS ) ( C , B ) ( T , S ) M = contract(TS, TS → TS ) ( G , L ) ( T , S ) Y = contract(TS, SP → TP ) ( M , X ) ( T , P ) (16) \begin{aligned} G & = \text{contract(TN, SN} \rightarrow \text{TS})(C, B) && (T, S) \\ M & = \text{contract(TS, TS} \rightarrow \text{TS})(G, L) && (T, S) \\ Y & = \text{contract(TS, SP} \rightarrow \text{TP})(M, X) && (T, P) \end{aligned} \tag{16} GMY=contract(TN, SNTS)(C,B)=contract(TS, TSTS)(G,L)=contract(TS, SPTP)(M,X)(T,S)(T,S)(T,P)(16)

其中 S = T S = T S=T。但这与原始掩码核注意力定义(13)完全相同!

因此,正如第3.4节所暗示的,通过具体化半可分矩阵 M M M 并执行二次矩阵-向量乘法来朴素地计算标量结构化SSM,与二次掩码核注意力完全相同。

5.2、1-半可分结构化掩码注意力

结构化掩码注意力允许使用任何结构化掩码 L L L。当 L L L是因果掩码时,它就是标准的线性注意力。注意,因果掩码是 L = SS ( 1 T ) L=\text{SS}(1_T) L=SS(1T),即1-SS掩码是由定义(6)中的 a t = 1 a_t=1 at=1生成的。这促使我们将 L L L推广到1-半可分掩码的类别,或1-半可分结构化掩码注意力(1-SS SMA),其中线性注意力的递归中的累加和(cumsum)被更一般的递归所替代——即标量SSM扫描,即1-半可分矩阵乘法(第3.2.2节)。

最后,我们考虑1-半可分SMA的最重要原因是因为计算它的线性形式是对角状态空间模型的一个特例。SMA的线性形式是算法(15),其中瓶颈步骤(15b)可以视为1-SS掩码的矩阵乘法。在第3节中,我们还写出了对角SSM(8)的计算过程,其中瓶颈步骤(8b)是标量SSM递归,它等价于1-SS乘法。唯一的区别是(8b)在 L L L中有一个额外的 N \mathrm{N} N维度,因为矩阵 A A A是一个大小为 N \mathrm{N} N的对角矩阵。如果 A A A的所有对角元素都相同,这个 N \mathrm{N} N维度就会消失,这导致了推论5.1。

推论5.1. 1-SS SMA(使用1-半可分结构化矩阵 L L L的掩码注意力)(15)是当对角矩阵是单位矩阵的标量倍数时,对角SSM(8)的一个特例。
虽然推论5.1指出1-SS SMA具有高效的递归形式,我们也可以展示一个相反的结果,即描述哪些SMA实例具有高效的自回归性。

定理5.2. 对于任何具有有界阶数的自回归过程的结构化掩码注意力(定义4.2)的实例化,结构化掩码 L L L必须是半可分矩阵。
换句话说,高效的自回归注意力是一般半可分SMA。定理5.2在附录C.2中得到了证明。

备注7. 尽管1-半可分SMA是状态空间模型的一个特例,但一般的半可分SMA比1-SS SMA具有更严格的表达能力,并且不能用标准的SSM来描述。然而, L L L的半可分乘法和SMA的线性形式(方程(15a))都涉及扩展和收缩步骤,并且可以被吸收到具有单一(更大)扩展的类似1-SS SMA实例中。

总之,1-半可分结构化注意力是SMA最重要的一个情况,因为它:

  • 是具有输入依赖递归的线性注意力的自然推广。
  • 是一般半可分注意力的最简单情况,它等同于高效的自回归注意力。
  • 是对角状态空间模型的一个特例。
    5.3 结构化状态空间对偶性(SSD)

总结我们的结果:

  • 结构化状态空间模型(第3节)通常是通过线性时间递归来定义的。然而,通过扩展描述其线性序列到序列转换的矩阵形式,可以推导出二次形式。
  • 注意力变体(第4节)是通过二次时间成对交互来定义的。但是,通过将其视为四路张量收缩并以不同的顺序进行约简,可以推导出线性形式。
  • 这两者中每个的自然特例——更具体地说, A A A矩阵具有标量恒等结构的状态空间模型,以及 L L L掩码具有1-半可分结构的结构化掩码注意力——在完全相同的线性和二次形式下互为对偶。

图4总结了这两种表示之间的对偶性。
扩展的相关工作和讨论(第10节)更详细地描述了SSD与一般SSM/注意力之间的关系。
在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com